Caffe2 - C++ API
A deep learning, cross-platform ML framework
transpose_op_cudnn.cc
#include "caffe2/operators/transpose_op.h"

#include <algorithm>
#include <limits>
#include <numeric>
#include <vector>

#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/core/types.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

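// cuDNN-backed implementation of the Transpose operator: permutes the
// dimensions of the input tensor according to the "axes" argument.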
class CuDNNTransposeOp final : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  template <class... Args>
  explicit CuDNNTransposeOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        axes_(OperatorBase::GetRepeatedArgument<int>("axes")) {
    // Check the legality of axes_: it should be a permutation of 0, 1, ..., ndim - 1.
    std::vector<int> axes_sorted(axes_);
    std::sort(axes_sorted.begin(), axes_sorted.end());
    for (std::size_t i = 0; i < axes_sorted.size(); ++i) {
      if (axes_sorted[i] != static_cast<int>(i)) {
        CAFFE_THROW("Axes should be a permutation of 0 to ndim.");
      }
    }

    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&X_desc_));
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&Y_desc_));
  }

  ~CuDNNTransposeOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(X_desc_));
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(Y_desc_));
  }

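  // Dispatch on the dtype of the input tensor; only float and int inputs are
  // handled by this operator.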
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, int>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);
    const int ndim = X.dim();
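    // An empty "axes" argument means reverse all dimensions: std::iota over the
    // reversed range fills axes_ with ndim - 1, ..., 0.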
    if (axes_.empty()) {
      axes_.resize(ndim);
      std::iota(axes_.rbegin(), axes_.rend(), 0);
    } else {
      CAFFE_ENFORCE_EQ(axes_.size(), ndim);
    }
    std::vector<std::int64_t> X_dims = X.sizes().vec();
    std::vector<std::int64_t> Y_dims(ndim);
    for (int i = 0; i < ndim; ++i) {
      Y_dims[i] = X_dims[axes_[i]];
    }
    auto* Y = Output(0, Y_dims, at::dtype<T>());
    const T* X_data = X.template data<T>();
    T* Y_data = Y->template mutable_data<T>();
    if (X.numel() == 0) {
      return true;
    }
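    // Fall back to the generic CUDA transpose kernel when the rank is outside
    // the range usable with cuDNN Nd descriptors or the element count does not
    // fit in a 32-bit int.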
    if (ndim < 3 || ndim > CUDNN_DIM_MAX ||
        X.numel() > std::numeric_limits<std::int32_t>::max()) {
      math::Transpose<std::int64_t, T, CUDAContext>(
          ndim, X_dims.data(), axes_.data(), X_data, Y_data, &context_);
      return true;
    }
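    // Only rebuild the cuDNN tensor descriptors when the input shape changes.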
    if (X_dims != cached_X_dims_) {
      SetTensorDescriptor(cudnnTypeWrapper<T>::type, X_dims, Y_dims);
      cached_X_dims_ = X_dims;
    }
    CUDNN_ENFORCE(cudnnTransformTensor(
        cudnn_wrapper_.inline_cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        X_desc_,
        X_data,
        cudnnTypeWrapper<T>::kZero(),
        Y_desc_,
        Y_data));
    return true;
  }

 private:
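  // Describes X as a tensor of shape Y_dims whose strides are X's contiguous
  // strides permuted by axes_, and Y as a contiguous tensor of the same shape,
  // so that cudnnTransformTensor performs the permutation while copying.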
  void SetTensorDescriptor(
      const cudnnDataType_t data_type,
      const std::vector<std::int64_t>& X_dims,
      const std::vector<std::int64_t>& Y_dims) {
    const int ndim = X_dims.size();
    std::vector<int> dims(Y_dims.cbegin(), Y_dims.cend());
    std::vector<int> X_strides(ndim);
    std::vector<int> X_buff(ndim);
    std::vector<int> Y_strides(ndim);
    X_buff.back() = 1;
    Y_strides.back() = 1;
    for (int i = ndim - 1; i > 0; --i) {
      X_buff[i - 1] = X_buff[i] * X_dims[i];
      Y_strides[i - 1] = Y_strides[i] * Y_dims[i];
    }
    for (int i = 0; i < ndim; ++i) {
      X_strides[i] = X_buff[axes_[i]];
    }
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        X_desc_, data_type, ndim, dims.data(), X_strides.data()));
    CUDNN_ENFORCE(cudnnSetTensorNdDescriptor(
        Y_desc_, data_type, ndim, dims.data(), Y_strides.data()));
  }

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t X_desc_;
  cudnnTensorDescriptor_t Y_desc_;

  std::vector<std::int64_t> cached_X_dims_;
  std::vector<std::int32_t> axes_;
};

#if !CUDNN_VERSION_MIN(6, 0, 0)

// cuDNN 5.1 does not have int support yet, so the int path always uses the
// generic CUDA transpose kernel.
template <>
bool CuDNNTransposeOp::DoRunWithType<int>() {
  const auto& X = Input(0);
  const int ndim = X.dim();
  if (axes_.empty()) {
    axes_.resize(ndim);
    std::iota(axes_.rbegin(), axes_.rend(), 0);
  } else {
    CAFFE_ENFORCE_EQ(axes_.size(), ndim);
  }
  std::vector<std::int64_t> X_dims = X.sizes().vec();
  std::vector<std::int64_t> Y_dims(ndim);
  for (int i = 0; i < ndim; ++i) {
    Y_dims[i] = X_dims[axes_[i]];
  }
  auto* Y = Output(0, Y_dims, at::dtype<int>());
  const int* X_data = X.template data<int>();
  int* Y_data = Y->template mutable_data<int>();
  math::Transpose<std::int64_t, int, CUDAContext>(
      ndim, X_dims.data(), axes_.data(), X_data, Y_data, &context_);
  return true;
}

#endif // !CUDNN_VERSION_MIN(6, 0, 0)

} // namespace

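// Register this class as the CUDNN-engine implementation of the Transpose
// operator on CUDA devices.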
REGISTER_CUDNN_OPERATOR(Transpose, CuDNNTransposeOp);

} // namespace caffe2