Caffe2 - C++ API
A deep learning, cross-platform ML framework
order_switch_ops_cudnn.cc
1 #include "caffe2/operators/order_switch_ops.h"
2 
3 #include <algorithm>
4 #include <functional>
5 #include <vector>
6 
7 #include "caffe2/core/context_gpu.h"
8 #include "caffe2/core/cudnn_wrappers.h"
9 #include "caffe2/core/types.h"
10 
11 namespace caffe2 {
12 
13 namespace {
14 
15 class CuDNNOrderSwithOpBase : public Operator<CUDAContext> {
16  public:
17  USE_OPERATOR_FUNCTIONS(CUDAContext);
18 
19  template <class... Args>
20  explicit CuDNNOrderSwithOpBase(Args&&... args)
21  : Operator<CUDAContext>(std::forward<Args>(args)...),
22  cudnn_wrapper_(&context_) {
23  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&X_desc_));
24  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&Y_desc_));
25  }
26 
27  ~CuDNNOrderSwithOpBase() override {
28  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(X_desc_));
29  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(Y_desc_));
30  }
31 
32  protected:
33  // TODO: std::vector<int> -> std::vector<int64_t>
34  void SetTensorDescriptor(
35  const cudnnDataType_t data_type,
36  const StorageOrder order,
37  const std::vector<int>& data_dims,
38  cudnnTensorDescriptor_t data_desc) const {
39  const int ndim = data_dims.size();
40  const int N = data_dims[0];
41  const int C = order == StorageOrder::NCHW ? data_dims[1] : data_dims.back();
42  if (ndim == 3) {
43  const int H = 1;
44  const int W = order == StorageOrder::NCHW ? data_dims[2] : data_dims[1];
45  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
46  data_desc, GetCudnnTensorFormat(order), data_type, N, C, H, W));
47  } else if (ndim == 4) {
48  const int H = order == StorageOrder::NCHW ? data_dims[2] : data_dims[1];
49  const int W = order == StorageOrder::NCHW ? data_dims[3] : data_dims[2];
50  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
51  data_desc, GetCudnnTensorFormat(order), data_type, N, C, H, W));
52  } else {
53  const int H = order == StorageOrder::NCHW ? data_dims[2] : data_dims[1];
54  const int W = order == StorageOrder::NCHW ? data_dims[3] : data_dims[2];
55  const auto l_iter = order == StorageOrder::NCHW ? data_dims.cbegin() + 4
56  : data_dims.cbegin() + 3;
57  const auto r_iter =
58  order == StorageOrder::NCHW ? data_dims.cend() : data_dims.cend() - 1;
59  const int D = std::accumulate(l_iter, r_iter, 1, std::multiplies<int>());
60  const std::array<int, 5> dims = {N, C, H, W, D};
61  const std::array<int, 5> strides = order == StorageOrder::NCHW
62  ? std::array<int, 5>{C * H * W * D, H * W * D, W * D, D, 1}
63  : std::array<int, 5>{C * H * W * D, 1, W * D * C, D * C, C};
64  CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
65  data_desc, data_type, 5, dims.data(), strides.data()));
66  }
67  }
68 
69  CuDNNWrapper cudnn_wrapper_;
70  cudnnTensorDescriptor_t X_desc_;
71  cudnnTensorDescriptor_t Y_desc_;
72 
73  std::vector<int> cached_X_dims_;
74 };
75 
76 class CuDNNNHWC2NCHWOp final : public CuDNNOrderSwithOpBase {
77  public:
78  template <class... Args>
79  explicit CuDNNNHWC2NCHWOp(Args&&... args)
80  : CuDNNOrderSwithOpBase(std::forward<Args>(args)...) {}
81 
82  bool RunOnDevice() override {
83  return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
84  }
85 
86  template <typename T>
87  bool DoRunWithType() {
88  const auto& X = Input(0);
89 
90  const int ndim = X.dim();
91  const int N = X.dim32(0);
92  const int C = X.dim32(ndim - 1);
93  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
94  std::vector<int> Y_dims(ndim);
95  Y_dims[0] = N;
96  Y_dims[1] = C;
97  std::copy(X_dims.cbegin() + 1, X_dims.cend() - 1, Y_dims.begin() + 2);
98  std::vector<int64_t> Y_dims_64;
99  std::copy(Y_dims.cbegin(), Y_dims.cend(), std::back_inserter(Y_dims_64));
100  auto* Y = Output(0, Y_dims_64, at::dtype<T>());
101  if (cached_X_dims_ != X_dims) {
102  cached_X_dims_ = X_dims;
103  SetTensorDescriptor(
104  cudnnTypeWrapper<T>::type, StorageOrder::NHWC, X_dims, X_desc_);
105  SetTensorDescriptor(
106  cudnnTypeWrapper<T>::type, StorageOrder::NCHW, Y_dims, Y_desc_);
107  }
108  CUDNN_ENFORCE(cudnnTransformTensor(
109  cudnn_wrapper_.inline_cudnn_handle(),
110  cudnnTypeWrapper<T>::kOne(),
111  X_desc_,
112  X.template data<T>(),
113  cudnnTypeWrapper<T>::kZero(),
114  Y_desc_,
115  Y->template mutable_data<T>()));
116  return true;
117  }
118 };
119 
120 class CuDNNNCHW2NHWCOp final : public CuDNNOrderSwithOpBase {
121  public:
122  template <class... Args>
123  explicit CuDNNNCHW2NHWCOp(Args&&... args)
124  : CuDNNOrderSwithOpBase(std::forward<Args>(args)...) {}
125 
126  bool RunOnDevice() override {
127  return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
128  }
129 
130  template <typename T>
131  bool DoRunWithType() {
132  const auto& X = Input(0);
133 
134  const int ndim = X.dim();
135  const int N = X.dim32(0);
136  const int C = X.dim32(1);
137  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
138  std::vector<int> Y_dims(ndim);
139  Y_dims[0] = N;
140  Y_dims[ndim - 1] = C;
141  std::copy(X_dims.cbegin() + 2, X_dims.cend(), Y_dims.begin() + 1);
142  std::vector<int64_t> Y_dims_64;
143  std::copy(Y_dims.cbegin(), Y_dims.cend(), std::back_inserter(Y_dims_64));
144  auto* Y = Output(0, Y_dims_64, at::dtype<T>());
145  if (cached_X_dims_ != X_dims) {
146  cached_X_dims_ = X_dims;
147  SetTensorDescriptor(
148  cudnnTypeWrapper<T>::type, StorageOrder::NCHW, X_dims, X_desc_);
149  SetTensorDescriptor(
150  cudnnTypeWrapper<T>::type, StorageOrder::NHWC, Y_dims, Y_desc_);
151  }
152  CUDNN_ENFORCE(cudnnTransformTensor(
153  cudnn_wrapper_.inline_cudnn_handle(),
154  cudnnTypeWrapper<T>::kOne(),
155  X_desc_,
156  X.template data<T>(),
157  cudnnTypeWrapper<T>::kZero(),
158  Y_desc_,
159  Y->template mutable_data<T>()));
160  return true;
161  }
162 };
163 
164 } // namespace
165 
// Registration of the two layout-conversion operators defined in this file.
// NOTE(review): REGISTER_CUDNN_OPERATOR is defined outside this file;
// presumably it binds the operator name (first argument) to the implementing
// class (second argument) for the cuDNN engine - confirm against its
// definition.
166 REGISTER_CUDNN_OPERATOR(NHWC2NCHW, CuDNNNHWC2NCHWOp);
167 REGISTER_CUDNN_OPERATOR(NCHW2NHWC, CuDNNNCHW2NHWCOp);
168 
169 } // namespace caffe2
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
Definition: common_cudnn.h:192
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:70