Caffe2 - C++ API
A deep learning, cross-platform ML framework
utility_ops_cudnn.cc
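This listing implements CuDNNWeightedSumOp, the cuDNN engine for Caffe2's WeightedSum operator. Given input pairs (X0, weight0), (X1, weight1), ... in which each weight blob holds a single scalar, the operator computes Y = weight0 * X0 + weight1 * X1 + ... on the GPU via cudnnAddTensor and cudnnOpTensor, and supports in-place updates when the output aliases input #0.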
#include "caffe2/operators/utility_ops.h"

#include <type_traits>

#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/utils/conversions.h"

namespace caffe2 {

class CuDNNWeightedSumOp : public Operator<CUDAContext> {
 public:
  USE_OPERATOR_FUNCTIONS(CUDAContext);

  template <class... Args>
  explicit CuDNNWeightedSumOp(Args&&... args)
      : Operator<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_) {
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
    CUDNN_ENFORCE(cudnnCreateOpTensorDescriptor(&add_desc_));
    // Both float and at::Half require opTensorCompType to be CUDNN_DATA_FLOAT.
    CUDNN_ENFORCE(cudnnSetOpTensorDescriptor(
        add_desc_, CUDNN_OP_TENSOR_ADD, CUDNN_DATA_FLOAT, CUDNN_PROPAGATE_NAN));
  }

  ~CuDNNWeightedSumOp() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
    CUDNN_ENFORCE(cudnnDestroyOpTensorDescriptor(add_desc_));
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, at::Half>>::call(this, Input(0));
  }

  template <typename T>
  bool DoRunWithType() {
    if (std::is_same<T, at::Half>::value) {
      LOG(WARNING)
          << "CuDNN only supports the same type for data and weight, "
             "so the weight will be cast to at::Half when the data type is Half.";
    }
    const int num_inputs = InputSize();
    CAFFE_ENFORCE_EQ(num_inputs % 2, 0);
    const auto& X0 = Input(0);
    const auto& weight0 = Input(1);
    CAFFE_ENFORCE_GT(X0.numel(), 0);
    CAFFE_ENFORCE_EQ(weight0.numel(), 1);
    const int input_size = X0.numel();
    SetTensorDescriptor(cudnnTypeWrapper<T>::type, input_size);

    // Note: removed the aliasing check, since Output already has
    // caching capability.
    auto* Y = Output(0, X0.sizes(), at::dtype<T>());
    T* Y_data = Y->template mutable_data<T>();
    T alpha = convert::To<float, T>(0.0f);
    T beta = convert::To<float, T>(0.0f);
    if (num_inputs == 2) {
      // Single input pair: Y = alpha * X0 (beta is still zero, so any previous
      // contents of Y are discarded).
      CopyWeightToHost<T>(weight0.template data<float>(), &alpha);
      CUDNN_ENFORCE(cudnnAddTensor(
          cudnn_wrapper_.inline_cudnn_handle(),
          &alpha,
          data_desc_,
          X0.template data<T>(),
          &beta,
          data_desc_,
          Y_data));
      return true;
    }
    const auto& X1 = Input(2);
    CAFFE_ENFORCE(
        !IsInputOutputAlias(2, 0),
        "Input #2 is the same as output. If you want to do in-place updates, "
        "put the output as input #0.");
    const auto& weight1 = Input(3);
    CAFFE_ENFORCE_EQ(X1.numel(), input_size);
    CAFFE_ENFORCE_EQ(weight1.numel(), 1);
    CopyWeightToHost<T>(weight1.template data<float>(), &alpha);
    CopyWeightToHost<T>(weight0.template data<float>(), &beta);
    if (IsInputOutputAlias(0, 0)) {
      // In-place update: Y aliases X0, so this computes Y = alpha * X1 + beta * Y.
      CUDNN_ENFORCE(cudnnAddTensor(
          cudnn_wrapper_.inline_cudnn_handle(),
          &alpha,
          data_desc_,
          X1.template data<T>(),
          &beta,
          data_desc_,
          Y_data));
    } else {
      // Out-of-place: Y = alpha * X1 + beta * X0, with a zero scale on Y.
      CUDNN_ENFORCE(cudnnOpTensor(
          cudnn_wrapper_.inline_cudnn_handle(),
          add_desc_,
          &alpha,
          data_desc_,
          X1.template data<T>(),
          &beta,
          data_desc_,
          X0.template data<T>(),
          cudnnTypeWrapper<T>::kZero(),
          data_desc_,
          Y_data));
    }
    for (int i = 4; i < num_inputs; i += 2) {
      const auto& Xi = Input(i);
      // Do a check: if the input is the same as output, we have a problem -
      // in-place update should always only happen with the zeroth input.
      const std::string err_msg = "Input #" + to_string(i) +
          " is the same as output. If you want to do in-place updates, "
          "put the output as input #0.";
      CAFFE_ENFORCE(!IsInputOutputAlias(i, 0), err_msg);
      const auto& weighti = Input(i + 1);
      CAFFE_ENFORCE_EQ(Xi.numel(), input_size);
      CAFFE_ENFORCE_EQ(weighti.numel(), 1);
      CopyWeightToHost<T>(weighti.template data<float>(), &alpha);
      // Accumulate the remaining pairs into Y: Y += alpha * Xi.
      CUDNN_ENFORCE(cudnnAddTensor(
          cudnn_wrapper_.inline_cudnn_handle(),
          &alpha,
          data_desc_,
          Xi.template data<T>(),
          cudnnTypeWrapper<T>::kOne(),
          data_desc_,
          Y_data));
    }
    return true;
  }

 private:
  void SetTensorDescriptor(
      const cudnnDataType_t data_type,
      const int input_size) {
    if (cached_input_size_ != input_size) {
      cached_input_size_ = input_size;
      // Since the best performance is obtained when the tensor is HW-packed, we
      // put X.size() in W.
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          data_desc_,
          GetCudnnTensorFormat(StorageOrder::NCHW),
          data_type,
          1,
          1,
          1,
          input_size));
    }
  }

  template <typename T>
  void CopyWeightToHost(const float* src, T* dst);

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t data_desc_;
  cudnnOpTensorDescriptor_t add_desc_;

  int cached_input_size_ = 0;
};

template <typename T>
void CuDNNWeightedSumOp::CopyWeightToHost(const float* src, T* dst) {
  float val;
  context_.template CopyToCPU<float>(1, src, &val);
  *dst = convert::To<float, T>(val);
}

template <>
void CuDNNWeightedSumOp::CopyWeightToHost<float>(const float* src, float* dst) {
  context_.CopyToCPU<float>(1, src, dst);
}

REGISTER_CUDNN_OPERATOR(WeightedSum, CuDNNWeightedSumOp);

} // namespace caffe2
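To make the alpha/beta bookkeeping above easier to follow, here is a plain CPU sketch of the arithmetic that the cuDNN calls implement. It is illustrative only; the function name WeightedSumReference and the std::vector-based signature are inventions for this example, not Caffe2 API.

// Illustrative CPU reference for the WeightedSum semantics (not the cuDNN path).
#include <cstddef>
#include <vector>

std::vector<float> WeightedSumReference(
    const std::vector<std::vector<float>>& inputs,  // X0, X1, ...
    const std::vector<float>& weights) {            // weight0, weight1, ... (scalars)
  std::vector<float> Y(inputs[0].size(), 0.0f);
  for (std::size_t k = 0; k < inputs.size(); ++k) {
    for (std::size_t i = 0; i < Y.size(); ++i) {
      Y[i] += weights[k] * inputs[k][i];  // Y += weight_k * X_k, elementwise
    }
  }
  return Y;
}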
Cross-references from this listing:
- GetCudnnTensorFormat(const StorageOrder& order): a wrapper function to convert the Caffe storage order to the cuDNN storage order enum values... (definition: common_cudnn.h:192)
- Input(int idx, DeviceType type = CUDAContext::GetDeviceType()): retrieves a non-owning reference to the input at position 'idx' for this operator... (definition: operator.h:702)
- namespace caffe2: a global dictionary that holds information about what Caffe2 modules have been loaded in the current... (definition: blob.h:13)
- CuDNNWrapper: a class that wraps the cuDNN handles and cuDNN workspaces.
- CuDNNWrapper::inline_cudnn_handle(): returns the inline cuDNN handle that executes on the current thread's cuda_stream.
- cudnnTypeWrapper: a wrapper class that allows us to refer to the cuDNN type in a template function... (definition: common_cudnn.h:120)
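For context, a minimal sketch of how this operator might be constructed and run from C++ is shown below, assuming a CUDA-enabled Caffe2 build. The blob names ("X0", "w0", ...) and the helper name RunCuDNNWeightedSumExample are placeholders for this example, and feeding the CUDA tensors into the workspace is omitted.

// Hedged usage sketch: build a WeightedSum OperatorDef that requests the
// CUDNN engine and run it against blobs already present in the workspace.
#include <memory>

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

namespace caffe2 {

void RunCuDNNWeightedSumExample(Workspace* ws) {
  OperatorDef def;
  def.set_type("WeightedSum");
  def.add_input("X0");  // data blob (CUDA tensor), placeholder name
  def.add_input("w0");  // scalar weight blob, placeholder name
  def.add_input("X1");
  def.add_input("w1");
  def.add_output("Y");
  def.set_engine("CUDNN");  // select CuDNNWeightedSumOp via REGISTER_CUDNN_OPERATOR
  def.mutable_device_option()->set_device_type(PROTO_CUDA);
  std::unique_ptr<OperatorBase> op = CreateOperator(def, ws);
  CAFFE_ENFORCE(op->Run());
}

} // namespace caffe2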