Caffe2 - C++ API
A deep learning, cross-platform ML framework
local_response_normalization_op_cudnn.cc
1 #include "caffe2/core/context_gpu.h"
2 #include "caffe2/core/cudnn_wrappers.h"
3 #include "caffe2/core/operator.h"
4 #include "caffe2/core/types.h"
5 
6 namespace caffe2 {
7 
8 class CuDNNLRNOp final : public Operator<CUDAContext> {
9  public:
10  USE_OPERATOR_FUNCTIONS(CUDAContext);
11 
12  template <class... Args>
13  explicit CuDNNLRNOp(Args&&... args)
14  : Operator<CUDAContext>(std::forward<Args>(args)...),
15  cudnn_wrapper_(&context_),
16  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
17  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
18  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
19  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
20  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
21 
22  CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
23  CUDNN_ENFORCE(
24  cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
25  }
26 
27  ~CuDNNLRNOp() override {
28  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
29  CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
30  }
31 
32  template <typename T, typename M>
33  bool DoRunWithType();
34 
35  bool RunOnDevice() override;
36 
37  protected:
38  CuDNNWrapper cudnn_wrapper_;
39  cudnnTensorDescriptor_t data_desc_;
40  cudnnLRNDescriptor_t norm_desc_;
41 
42  vector<int64_t> cudnn_input_dims_;
43 
44  const int size_;
45  const float alpha_;
46  const float beta_;
47  const float bias_;
48 
49  // Input: X, Output: Y
50 };
51 
52 class CuDNNLRNGradientOp final : public Operator<CUDAContext> {
53  public:
54  USE_OPERATOR_FUNCTIONS(CUDAContext);
55  template <class... Args>
56  explicit CuDNNLRNGradientOp(Args&&... args)
57  : Operator<CUDAContext>(std::forward<Args>(args)...),
58  cudnn_wrapper_(&context_),
59  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
60  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
61  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
62  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
63  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
64 
65  CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
66  CUDNN_ENFORCE(
67  cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
68  }
69 
70  ~CuDNNLRNGradientOp() override {
71  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
72  CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
73  }
74 
75  template <typename T, typename M>
76  bool DoRunWithType();
77 
78  bool RunOnDevice() override;
79 
80  protected:
81  CuDNNWrapper cudnn_wrapper_;
82  cudnnTensorDescriptor_t data_desc_;
83  cudnnLRNDescriptor_t norm_desc_;
84 
85  vector<int64_t> cudnn_input_dims_;
86 
87  const int size_;
88  const float alpha_;
89  const float beta_;
90  const float bias_;
91 
92  // Input: X, Y, dY
93  // Output: dX
94 };
95 
96 template <typename T, typename M>
97 bool CuDNNLRNOp::DoRunWithType() {
98  const auto& X = Input(0);
99  auto* Y = Output(0);
100 
101  // Reshape tensor descriptors if necessary
102  if (X.sizes() != cudnn_input_dims_) {
103  VLOG(1) << "Setting descriptors";
104  cudnn_input_dims_ = X.sizes().vec();
105  int C = 1, H = 1, W = 1;
106  // Normal 4-dimensional tensors for images.
107  C = X.dim32(1);
108  H = X.dim32(2);
109  W = X.dim32(3);
110  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
111  data_desc_,
112  GetCudnnTensorFormat(StorageOrder::NCHW),
114  X.dim32(0),
115  C,
116  H,
117  W));
118  }
119 
120  // now actually run the computation
121  CUDNN_ENFORCE(cudnnLRNCrossChannelForward(
122  cudnn_wrapper_.inline_cudnn_handle(),
123  norm_desc_,
124  CUDNN_LRN_CROSS_CHANNEL_DIM1,
126  data_desc_,
127  X.template data<T>(),
129  data_desc_,
130  Y->template mutable_data<T>()));
131 
132  return true;
133 }
134 
135 bool CuDNNLRNOp::RunOnDevice() {
136  // dispatch based on contents of tensor(s)
137  const auto& X = Input(0);
138  auto* Y = Output(0);
139  Y->ResizeLike(X);
140 
141  if (X.IsType<float>()) {
142  return DoRunWithType<float, float>();
143  } else if (X.IsType<at::Half>()) {
144  return DoRunWithType<at::Half, float>();
145  } else {
146  CAFFE_THROW("Unsupported input type");
147  }
148  return false;
149 }
150 
151 template <typename T, typename M>
152 bool CuDNNLRNGradientOp::DoRunWithType() {
153  const auto& X = Input(0);
154  const auto& Y = Input(1);
155  const auto& dY = Input(2);
156  auto* dX = Output(0);
157 
158  if (dY.sizes() != cudnn_input_dims_) {
159  VLOG(1) << "Setting descriptors";
160  cudnn_input_dims_ = dY.sizes().vec();
161  int C = 1, H = 1, W = 1;
162  // Normal 4-dimensional tensors for images.
163  C = dY.dim32(1);
164  H = dY.dim32(2);
165  W = dY.dim32(3);
166  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
167  data_desc_,
168  GetCudnnTensorFormat(StorageOrder::NCHW),
170  dY.dim32(0),
171  C,
172  H,
173  W));
174  }
175 
176  // run the computation
177  CUDNN_ENFORCE(cudnnLRNCrossChannelBackward(
178  cudnn_wrapper_.inline_cudnn_handle(),
179  norm_desc_,
180  CUDNN_LRN_CROSS_CHANNEL_DIM1,
182  data_desc_,
183  Y.template data<T>(),
184  data_desc_,
185  dY.template data<T>(),
186  data_desc_,
187  X.template data<T>(),
189  data_desc_,
190  dX->template mutable_data<T>()));
191  return true;
192 }
193 
194 bool CuDNNLRNGradientOp::RunOnDevice() {
195  // dispatch based on contents of tensor(s)
196  const auto& X = Input(0);
197  const auto& Y = Input(1);
198  const auto& dY = Input(2);
199  auto* dX = Output(0);
200 
201  dX->ResizeLike(dY);
202 
203  if (dY.IsType<float>()) {
204  return DoRunWithType<float, float>();
205  } else if (dY.IsType<at::Half>()) {
206  return DoRunWithType<at::Half, float>();
207  } else {
208  CAFFE_THROW("Unsupported input type");
209  }
210 
211  return false;
212 }
213 
214 namespace {
215 REGISTER_CUDNN_OPERATOR(LRN, CuDNNLRNOp);
216 REGISTER_CUDNN_OPERATOR(LRNGradient, CuDNNLRNGradientOp);
217 }
218 
219 }; // namespace caffe2
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
Definition: common_cudnn.h:192
const Tensor & Input(int idx, DeviceType type=CUDAContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...
Definition: common_cudnn.h:120