Caffe2 - C++ API
A deep learning, cross platform ML framework
local_response_normalization_op_cudnn.cc
1 
17 #include "caffe2/core/context_gpu.h"
18 #include "caffe2/core/cudnn_wrappers.h"
19 #include "caffe2/core/operator.h"
20 #include "caffe2/core/types.h"
21 
22 namespace caffe2 {
23 
24 class CuDNNLRNOp final : public Operator<CUDAContext> {
25  public:
26  USE_OPERATOR_FUNCTIONS(CUDAContext);
27 
28  CuDNNLRNOp(const OperatorDef& operator_def, Workspace* ws)
29  : Operator<CUDAContext>(operator_def, ws),
30  cudnn_wrapper_(&context_),
31  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
32  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
33  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
34  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
35  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
36 
37  CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
38  CUDNN_ENFORCE(
39  cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
40  }
41 
42  ~CuDNNLRNOp() {
43  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
44  CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
45  }
46 
47  template <typename T, typename M>
48  bool DoRunWithType();
49 
50  bool RunOnDevice() override;
51 
52  protected:
53  CuDNNWrapper cudnn_wrapper_;
54  cudnnTensorDescriptor_t data_desc_;
55  cudnnLRNDescriptor_t norm_desc_;
56 
57  vector<TIndex> cudnn_input_dims_;
58 
59  const int size_;
60  const float alpha_;
61  const float beta_;
62  const float bias_;
63 
64  // Input: X, Output: Y
65 };
66 
67 class CuDNNLRNGradientOp final : public Operator<CUDAContext> {
68  public:
69  USE_OPERATOR_FUNCTIONS(CUDAContext);
70  CuDNNLRNGradientOp(const OperatorDef& operator_def, Workspace* ws)
71  : Operator<CUDAContext>(operator_def, ws),
72  cudnn_wrapper_(&context_),
73  size_(OperatorBase::GetSingleArgument<int>("size", 0)),
74  alpha_(OperatorBase::GetSingleArgument<float>("alpha", 0)),
75  beta_(OperatorBase::GetSingleArgument<float>("beta", 0)),
76  bias_(OperatorBase::GetSingleArgument<float>("bias", 1)) {
77  CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&data_desc_));
78 
79  CUDNN_ENFORCE(cudnnCreateLRNDescriptor(&norm_desc_));
80  CUDNN_ENFORCE(
81  cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, bias_));
82  }
83 
85  CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(data_desc_));
86  CUDNN_ENFORCE(cudnnDestroyLRNDescriptor(norm_desc_));
87  }
88 
89  template <typename T, typename M>
90  bool DoRunWithType();
91 
92  bool RunOnDevice() override;
93 
94  protected:
95  CuDNNWrapper cudnn_wrapper_;
96  cudnnTensorDescriptor_t data_desc_;
97  cudnnLRNDescriptor_t norm_desc_;
98 
99  vector<TIndex> cudnn_input_dims_;
100 
101  const int size_;
102  const float alpha_;
103  const float beta_;
104  const float bias_;
105 
106  // Input: X, Y, dY
107  // Output: dX
108 };
109 
110 template <typename T, typename M>
111 bool CuDNNLRNOp::DoRunWithType() {
112  const auto& X = Input(0);
113  auto* Y = Output(0);
114 
115  // Reshape tensor descriptors if necessary
116  if (X.dims() != cudnn_input_dims_) {
117  VLOG(1) << "Setting descriptors";
118  cudnn_input_dims_ = X.dims();
119  int C = 1, H = 1, W = 1;
120  // Normal 4-dimensional tensors for images.
121  C = X.dim32(1);
122  H = X.dim32(2);
123  W = X.dim32(3);
124  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
125  data_desc_,
126  GetCudnnTensorFormat(StorageOrder::NCHW),
128  X.dim32(0),
129  C,
130  H,
131  W));
132  }
133 
134  // now actually run the computation
135  CUDNN_ENFORCE(cudnnLRNCrossChannelForward(
136  cudnn_wrapper_.inline_cudnn_handle(),
137  norm_desc_,
138  CUDNN_LRN_CROSS_CHANNEL_DIM1,
140  data_desc_,
141  X.template data<T>(),
143  data_desc_,
144  Y->template mutable_data<T>()));
145 
146  return true;
147 }
148 
149 bool CuDNNLRNOp::RunOnDevice() {
150  // dispatch based on contents of tensor(s)
151  const auto& X = Input(0);
152  auto* Y = Output(0);
153  Y->ResizeLike(X);
154 
155  if (X.IsType<float>()) {
156  return DoRunWithType<float, float>();
157  } else if (X.IsType<float16>()) {
158  return DoRunWithType<float16, float>();
159  } else {
160  CAFFE_THROW("Unsupported input type");
161  }
162  return false;
163 }
164 
165 template <typename T, typename M>
166 bool CuDNNLRNGradientOp::DoRunWithType() {
167  const auto& X = Input(0);
168  const auto& Y = Input(1);
169  const auto& dY = Input(2);
170  auto* dX = Output(0);
171 
172  if (dY.dims() != cudnn_input_dims_) {
173  VLOG(1) << "Setting descriptors";
174  cudnn_input_dims_ = dY.dims();
175  int C = 1, H = 1, W = 1;
176  // Normal 4-dimensional tensors for images.
177  C = dY.dim32(1);
178  H = dY.dim32(2);
179  W = dY.dim32(3);
180  CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
181  data_desc_,
182  GetCudnnTensorFormat(StorageOrder::NCHW),
184  dY.dim32(0),
185  C,
186  H,
187  W));
188  }
189 
190  // run the computation
191  CUDNN_ENFORCE(cudnnLRNCrossChannelBackward(
192  cudnn_wrapper_.inline_cudnn_handle(),
193  norm_desc_,
194  CUDNN_LRN_CROSS_CHANNEL_DIM1,
196  data_desc_,
197  Y.template data<T>(),
198  data_desc_,
199  dY.template data<T>(),
200  data_desc_,
201  X.template data<T>(),
203  data_desc_,
204  dX->template mutable_data<T>()));
205  return true;
206 }
207 
208 bool CuDNNLRNGradientOp::RunOnDevice() {
209  // dispatch based on contents of tensor(s)
210  const auto& X = Input(0);
211  const auto& Y = Input(1);
212  const auto& dY = Input(2);
213  auto* dX = Output(0);
214 
215  dX->ResizeLike(dY);
216 
217  if (dY.IsType<float>()) {
218  return DoRunWithType<float, float>();
219  } else if (dY.IsType<float16>()) {
220  return DoRunWithType<float16, float>();
221  } else {
222  CAFFE_THROW("Unsupported input type");
223  }
224 
225  return false;
226 }
227 
228 namespace {
229 REGISTER_CUDNN_OPERATOR(LRN, CuDNNLRNOp);
230 REGISTER_CUDNN_OPERATOR(LRNGradient, CuDNNLRNGradientOp);
231 }
232 
233 }; // namespace caffe2
cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder &order)
A wrapper function to convert the Caffe storage order to cudnn storage order enum values...
Definition: common_cudnn.h:199
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.
CuDNNWrapper is a class that wraps the cudnn handles and cudnn workspaces.
cudnnTypeWrapper is a wrapper class that allows us to refer to the cudnn type in a template function...
Definition: common_cudnn.h:127