Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_transpose_op_cudnn.cc
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/operators/conv_op_cache_cudnn.h"
#include "caffe2/operators/conv_transpose_op.h"
#include "caffe2/operators/op_utils_cudnn.h"

namespace caffe2 {

class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
 public:
  template <class... Args>
  explicit CudnnConvTransposeOpBase(Args&&... args)
      : ConvTransposeUnpoolBase<CUDAContext>(std::forward<Args>(args)...),
        cudnn_wrapper_(&context_),
        cudnn_ws_nbytes_limit_(OperatorBase::GetSingleArgument<size_t>(
            "ws_nbytes_limit",
            kCONV_CUDNN_WORKSPACE_LIMIT_BYTES)),
        exhaustive_search_(
            OperatorBase::GetSingleArgument<int>("exhaustive_search", 0)),
        deterministic_(
            OperatorBase::GetSingleArgument<int>("deterministic", 0)),
        cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0)),
        force_algo_(OperatorBase::GetRepeatedArgument<int>(
            "force_algo",
            vector<int>{-1, -1, -1})),
        enable_tensor_core_(
            OperatorBase::GetSingleArgument<bool>("enable_tensor_core", 1)) {
    CAFFE_ENFORCE(!deterministic_ || !exhaustive_search_);

    bool individual_force_algo = OperatorBase::HasArgument("force_algo_fwd") ||
        OperatorBase::HasArgument("force_algo_dgrad") ||
        OperatorBase::HasArgument("force_algo_wgrad");
    if (OperatorBase::HasArgument("force_algo")) {
      CAFFE_ENFORCE(
          !individual_force_algo,
          "Cannot specify both force_algo and any of ",
          "force_algo_fwd, force_algo_dgrad, force_algo_wgrad");
    } else {
      force_algo_ = std::vector<int>{-1, -1, -1};
      force_algo_[ALGO_FWD] =
          OperatorBase::GetSingleArgument<int>("force_algo_fwd", -1);
      force_algo_[ALGO_DGRAD] =
          OperatorBase::GetSingleArgument<int>("force_algo_dgrad", -1);
      force_algo_[ALGO_WGRAD] =
          OperatorBase::GetSingleArgument<int>("force_algo_wgrad", -1);
    }

    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bottom_desc_));
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_));
    if (InputSize() == 3) {
      CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_));
    }
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
    CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_));
  }

  ~CudnnConvTransposeOpBase() override {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bottom_desc_));
    CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_));
    if (InputSize() == 3) {
      CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_));
    }
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
    CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_));
  }

 protected:
  vector<int64_t> cudnn_input_dims_;
  vector<int64_t> cudnn_filter_dims_;

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t bottom_desc_;
  cudnnFilterDescriptor_t filter_desc_;
  cudnnTensorDescriptor_t bias_desc_;
  cudnnTensorDescriptor_t top_desc_;
  cudnnConvolutionDescriptor_t conv_desc_;
  const size_t cudnn_ws_nbytes_limit_;
  size_t cudnn_ws_nbytes_;
  bool exhaustive_search_;
  bool deterministic_;
  size_t cudnn_state_;
  vector<int> force_algo_; // stored as FWD, dFILTER, dDATA
  bool enable_tensor_core_;
};

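// Forward transposed convolution ("deconvolution"): the output Y is produced
// with cuDNN's backward-data convolution kernel, since a transposed
// convolution is the data gradient of the corresponding regular convolution.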
template <typename T>
class CudnnConvTransposeOp final : public CudnnConvTransposeOpBase {
 public:
  template <class... Args>
  explicit CudnnConvTransposeOp(Args&&... args)
      : CudnnConvTransposeOpBase(std::forward<Args>(args)...) {}

  ~CudnnConvTransposeOp() override {}

  bool RunOnDevice() override;

 private:
  cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
  AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t> data_algo_cache_;
  // Input: X, W, b
  // Output: Y
  INPUT_TAGS(INPUT, FILTER, BIAS);
};

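// Gradient of ConvTranspose: dW comes from cuDNN's backward-filter kernel,
// db (if present) from backward-bias, and dX from the cuDNN forward
// convolution applied to dY.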
template <typename T>
class CudnnConvTransposeGradientOp final : public CudnnConvTransposeOpBase {
 public:
  template <class... Args>
  explicit CudnnConvTransposeGradientOp(Args&&... args)
      : CudnnConvTransposeOpBase(std::forward<Args>(args)...),
        no_bias_(OperatorBase::GetSingleArgument<bool>("no_bias", false)) {
    CAFFE_ENFORCE(
        !(no_bias_ && OutputSize() == 3),
        "If bias is not present, you should not have 3 grad outputs.");
  }

  ~CudnnConvTransposeGradientOp() override {}

  bool RunOnDevice() override;

 private:
  cudnnConvolutionFwdAlgo_t algo_;
  cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_algo_cache_;
  AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t> filter_algo_cache_;
  const bool no_bias_;
  // input: X, W, dY
  // output: dW, optionally db and dX
  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
};

// Implementations

template <typename T>
bool CudnnConvTransposeOp<T>::RunOnDevice() {
  auto& X = Input(INPUT);
  auto& filter = Input(FILTER);
  int C = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      C = filter.dim32(3);
      break;
    case StorageOrder::NCHW:
      C = filter.dim32(1);
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }
  auto sizes = ConvTransposeUnpoolBase<CUDAContext>::GetOutputSize(X, C);
  auto* Y = Output(0, sizes, at::dtype<T>());

  int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      N = X.dim32(0);
      H = X.dim32(1);
      W = X.dim32(2);
      M = X.dim32(3);
      H_out = Y->dim32(1);
      W_out = Y->dim32(2);
      CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
      CAFFE_ENFORCE_EQ(filter.dim32(3), C);
      break;
    case StorageOrder::NCHW:
      N = X.dim32(0);
      M = X.dim32(1);
      H = X.dim32(2);
      W = X.dim32(3);
      H_out = Y->dim32(2);
      W_out = Y->dim32(3);
      CAFFE_ENFORCE_EQ(filter.dim32(1), C);
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }

  if (InputSize() == 3) {
    auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.dim(), 1);
    CAFFE_ENFORCE_EQ(bias.dim32(0), C);
  }

  // Set up the cudnn algorithms & workspace if necessary
  bool input_changed = (X.sizes() != cudnn_input_dims_);
  bool filter_changed = (filter.sizes() != cudnn_filter_dims_);

  if (input_changed || filter_changed) {
    VLOG(1) << "Changing the cudnn descriptor configurations.";
    if (input_changed) {
      cudnn_input_dims_ = X.sizes().vec();
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          bottom_desc_,
          GetCudnnTensorFormat(order_),
          cudnnTypeWrapper<T>::type,
          N,
          M,
          H,
          W));
    }
    if (filter_changed) {
      cudnn_filter_dims_ = filter.sizes().vec();
      CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
          filter_desc_,
          cudnnTypeWrapper<T>::type,
          GetCudnnTensorFormat(order_),
          M,
          C,
          kernel_h(),
          kernel_w()));
      if (InputSize() == 3) {
        CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
            bias_desc_,
            GetCudnnTensorFormat(order_),
            cudnnTypeWrapper<T>::type,
            1,
            C,
            1,
            1));
      }
    }
    // Set the output
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        top_desc_,
        GetCudnnTensorFormat(order_),
        cudnnTypeWrapper<T>::type,
        N,
        C,
        H_out,
        W_out));
    // cuDNN only supports symmetric padding
    CAFFE_ENFORCE_EQ(
        pad_t(),
        pad_b(),
        "The current padding scheme leads to unequal padding on the top and "
        "bottom, which is not supported by cudnn.");
    CAFFE_ENFORCE_EQ(
        pad_l(),
        pad_r(),
        "The current padding scheme leads to unequal padding on the left "
        "and right, which is not supported by cudnn.");
    // Set the convolution descriptor
#if CUDNN_VERSION_MIN(6, 0, 0)
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION,
        cudnnTypeWrapper<T>::type));
#else
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION));
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
    // enable TensorCore math if desired
    enable_tensor_core_ &= TensorCoreAvailable();
    if (enable_tensor_core_) {
      CUDNN_ENFORCE(
          cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
    }
#endif
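    // Pick the backward-data algorithm used to produce Y. Four mutually
    // exclusive paths: an explicitly forced algorithm, the deterministic
    // algorithm, an exhaustive cudnnFind* search (cached per input/filter
    // shape), or the cuDNN heuristic constrained by the workspace limit.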
    if (force_algo_[ALGO_DGRAD] >= 0) {
      bwd_data_algo_ = (cudnnConvolutionBwdDataAlgo_t)force_algo_[ALGO_DGRAD];
    } else if (deterministic_) {
      bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
    } else if (exhaustive_search_) {
      bwd_data_algo_ =
          data_algo_cache_.getAlgorithm(X.sizes(), filter.sizes(), 0, [&]() {
            int returned_algo_count;
            std::array<
                cudnnConvolutionBwdDataAlgoPerf_t,
                kNUM_CUDNN_BWD_DATA_ALGS>
                data_perf_stat;
            cudnn_wrapper_.with_cudnn_state(
                cudnn_state_, [&](CuDNNState* state) {
                  state->workspace().reset();
                  CUDNN_ENFORCE(cudnnFindConvolutionBackwardDataAlgorithm(
                      state->cudnn_handle(),
                      filter_desc_,
                      bottom_desc_,
                      conv_desc_,
                      top_desc_,
                      kNUM_CUDNN_BWD_DATA_ALGS,
                      &returned_algo_count,
                      data_perf_stat.data()));
                });

            LogCuDNNPerfStats(data_perf_stat, returned_algo_count);
            return data_perf_stat[0].algo;
          });
    } else {
      CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataAlgorithm(
          cudnn_wrapper_.inline_cudnn_handle(),
          filter_desc_,
          bottom_desc_,
          conv_desc_,
          top_desc_,
          CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
          cudnn_ws_nbytes_limit_,
          &bwd_data_algo_));
    }

    size_t bwd_data_ws_size;
    CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        filter_desc_,
        bottom_desc_,
        conv_desc_,
        top_desc_,
        bwd_data_algo_,
        &bwd_data_ws_size));
    cudnn_ws_nbytes_ = bwd_data_ws_size;
    VLOG(1) << "CuDNN algorithm: " << bwd_data_algo_;
    VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size;
  }

  // Now, actually run the computation.
  // Y = ConvTranspose(X, W), computed with the cuDNN backward-data kernel.
  cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
    CUDNN_ENFORCE(cudnnConvolutionBackwardData(
        state->cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        filter_desc_,
        filter.template data<T>(),
        bottom_desc_,
        X.template data<T>(),
        conv_desc_,
        bwd_data_algo_,
        state->workspace().get(cudnn_ws_nbytes_),
        cudnn_ws_nbytes_,
        cudnnTypeWrapper<T>::kZero(),
        top_desc_,
        Y->template mutable_data<T>()));
  });
  // Bias (added in place on top of Y, hence beta = 1)
  if (InputSize() == 3) {
    CUDNN_ENFORCE(cudnnAddTensor(
        cudnn_wrapper_.inline_cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        bias_desc_,
        Input(BIAS).template data<T>(),
        cudnnTypeWrapper<T>::kOne(),
        top_desc_,
        Y->template mutable_data<T>()));
  }
  // Done.
  return true;
}

// TODO(Yangqing): a lot of the function contents are very similar. Consider
// consolidating them.
template <typename T>
bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
  auto& X = Input(INPUT);
  auto& filter = Input(FILTER);
  auto& dY = Input(OUTPUT_GRAD);

  CAFFE_ENFORCE_EQ(X.dim(), 4);
  CAFFE_ENFORCE_EQ(filter.dim(), 4);
  int C = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      C = filter.dim32(3);
      break;
    case StorageOrder::NCHW:
      C = filter.dim32(1);
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }

  int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      N = X.dim32(0);
      H = X.dim32(1);
      W = X.dim32(2);
      M = X.dim32(3);
      H_out = dY.dim32(1);
      W_out = dY.dim32(2);
      CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
      CAFFE_ENFORCE_EQ(filter.dim32(3), C);
      break;
    case StorageOrder::NCHW:
      N = X.dim32(0);
      M = X.dim32(1);
      H = X.dim32(2);
      W = X.dim32(3);
      H_out = dY.dim32(2);
      W_out = dY.dim32(3);
      CAFFE_ENFORCE_EQ(filter.dim32(1), C);
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }
  // Since we only handle LegacyPadding::NOTSET, we don't need to
  // compute padding.
  auto* dfilter = Output(FILTER_GRAD, filter.sizes(), at::dtype<T>());

  // Set up the cudnn algorithms & workspace if necessary
  bool input_changed = (X.sizes() != cudnn_input_dims_);
  bool filter_changed = (filter.sizes() != cudnn_filter_dims_);
  if (input_changed || filter_changed) {
    VLOG(1) << "Changing the cudnn descriptor configurations.";
    if (input_changed) {
      cudnn_input_dims_ = X.sizes().vec();
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          bottom_desc_,
          GetCudnnTensorFormat(order_),
          cudnnTypeWrapper<T>::type,
          N,
          M,
          H,
          W));
    }
    if (filter_changed) {
      cudnn_filter_dims_ = filter.sizes().vec();
      CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
          filter_desc_,
          cudnnTypeWrapper<T>::type,
          GetCudnnTensorFormat(order_),
          M,
          C,
          kernel_h(),
          kernel_w()));
      if (!no_bias_) {
        CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
            bias_desc_,
            GetCudnnTensorFormat(order_),
            cudnnTypeWrapper<T>::type,
            1,
            C,
            1,
            1));
      }
    }
    // Set the output
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        top_desc_,
        GetCudnnTensorFormat(order_),
        cudnnTypeWrapper<T>::type,
        N,
        C,
        H_out,
        W_out));
    // cuDNN only supports symmetric padding
    CAFFE_ENFORCE_EQ(
        pad_t(),
        pad_b(),
        "The current padding scheme leads to unequal padding on the top and "
        "bottom, which is not supported by cudnn.");
    CAFFE_ENFORCE_EQ(
        pad_l(),
        pad_r(),
        "The current padding scheme leads to unequal padding on the left "
        "and right, which is not supported by cudnn.");
    // Set the convolution descriptor
#if CUDNN_VERSION_MIN(6, 0, 0)
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION,
        cudnnTypeWrapper<T>::type));
#else
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION));
#endif
#if CUDNN_VERSION_MIN(7, 0, 0)
    // enable TensorCore math if desired
    enable_tensor_core_ &= TensorCoreAvailable();
    if (enable_tensor_core_) {
      CUDNN_ENFORCE(
          cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
    }
#endif
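    // Pick algorithms for the two cuDNN kernels this gradient op needs:
    // backward-filter (for dW) and forward convolution (for dX). Same four
    // paths as the forward op: forced, deterministic, exhaustive search
    // (cached per shape), or the cuDNN heuristic under the workspace limit.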
    if (force_algo_[ALGO_WGRAD] >= 0) {
      bwd_filter_algo_ =
          (cudnnConvolutionBwdFilterAlgo_t)force_algo_[ALGO_WGRAD];
    } else if (deterministic_) {
      algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
      bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
    } else if (exhaustive_search_) {
      bwd_filter_algo_ =
          filter_algo_cache_.getAlgorithm(X.sizes(), filter.sizes(), 0, [&]() {
            LOG(INFO) << "CUDNN Convolution bwd: doing exhaustive search.";
            // When we do an exhaustive search, we will ignore the workspace
            // size limit and simply go for the fastest algorithm. If you
            // happen to run out of memory later, you will be on your own...
            int returned_algo_count;
            // We clean up the current workspace memory so that the forward
            // algorithm is free to allocate memory.
            // Actually run the search.
            std::array<
                cudnnConvolutionBwdFilterAlgoPerf_t,
                kNUM_CUDNN_BWD_FILTER_ALGS>
                filter_perf_stat;

            cudnn_wrapper_.with_cudnn_state(
                cudnn_state_, [&](CuDNNState* state) {
                  state->workspace().reset();
                  CUDNN_ENFORCE(cudnnFindConvolutionBackwardFilterAlgorithm(
                      state->cudnn_handle(),
                      top_desc_,
                      bottom_desc_,
                      conv_desc_,
                      filter_desc_,
                      kNUM_CUDNN_BWD_FILTER_ALGS,
                      &returned_algo_count,
                      filter_perf_stat.data()));
                });
            LogCuDNNPerfStats(filter_perf_stat, returned_algo_count);
            return filter_perf_stat[0].algo;
          });

      algo_ =
          forward_algo_cache_.getAlgorithm(X.sizes(), filter.sizes(), 0, [&]() {
            int returned_algo_count;
            std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
                fwd_perf_stat;
            cudnn_wrapper_.with_cudnn_state(
                cudnn_state_, [&](CuDNNState* state) {
                  state->workspace().reset();
                  CUDNN_ENFORCE(cudnnFindConvolutionForwardAlgorithm(
                      state->cudnn_handle(),
                      top_desc_,
                      filter_desc_,
                      conv_desc_,
                      bottom_desc_,
                      kNUM_CUDNN_FWD_ALGS,
                      &returned_algo_count,
                      fwd_perf_stat.data()));
                });

            LogCuDNNPerfStats(fwd_perf_stat, returned_algo_count);
            return fwd_perf_stat[0].algo;
          });
    } else {
      // choose backward algorithm for filter
      CUDNN_ENFORCE(cudnnGetConvolutionBackwardFilterAlgorithm(
          cudnn_wrapper_.inline_cudnn_handle(),
          top_desc_,
          bottom_desc_,
          conv_desc_,
          filter_desc_,
          CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
          cudnn_ws_nbytes_limit_,
          &bwd_filter_algo_));
      // choose forward algorithm (used to compute the gradient w.r.t. the data)
      CUDNN_ENFORCE(cudnnGetConvolutionForwardAlgorithm(
          cudnn_wrapper_.inline_cudnn_handle(),
          top_desc_,
          filter_desc_,
          conv_desc_,
          bottom_desc_,
          CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
          cudnn_ws_nbytes_limit_,
          &algo_));
    }
    // get workspace size for the backward-filter algorithm
    size_t bwd_filter_ws_size, fwd_ws_size;
    CUDNN_ENFORCE(cudnnGetConvolutionBackwardFilterWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        top_desc_,
        bottom_desc_,
        conv_desc_,
        filter_desc_,
        bwd_filter_algo_,
        &bwd_filter_ws_size));
    // get workspace size for the forward algorithm (computes dX)
    CUDNN_ENFORCE(cudnnGetConvolutionForwardWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        top_desc_,
        filter_desc_,
        conv_desc_,
        bottom_desc_,
        algo_,
        &fwd_ws_size));
    cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, fwd_ws_size);

    VLOG(1) << "CuDNN bwd algorithm: " << bwd_filter_algo_ << ", " << algo_;
    VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
  }

  // Now, actually run the computation.
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD, {C}, at::dtype<T>());
    CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
        cudnn_wrapper_.inline_cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        top_desc_,
        dY.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        bias_desc_,
        dbias->template mutable_data<T>()));
  }

  cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
    CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
        state->cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        top_desc_,
        dY.template data<T>(),
        bottom_desc_,
        X.template data<T>(),
        conv_desc_,
        bwd_filter_algo_,
        state->workspace().get(cudnn_ws_nbytes_),
        cudnn_ws_nbytes_,
        cudnnTypeWrapper<T>::kZero(),
        filter_desc_,
        dfilter->template mutable_data<T>()));

    if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
      // Compute the gradient w.r.t. the input.
      auto* dX = Output(
          no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD,
          X.sizes(),
          at::dtype<T>());
      CUDNN_ENFORCE(cudnnConvolutionForward(
          state->cudnn_handle(),
          cudnnTypeWrapper<T>::kOne(),
          top_desc_,
          dY.template data<T>(),
          filter_desc_,
          filter.template data<T>(),
          conv_desc_,
          algo_,
          state->workspace().get(cudnn_ws_nbytes_),
          cudnn_ws_nbytes_,
          cudnnTypeWrapper<T>::kZero(),
          bottom_desc_,
          dX->template mutable_data<T>()));
    }
  });
  return true;
}

REGISTER_CUDNN_OPERATOR(ConvTranspose, CudnnConvTransposeOp<float>);
REGISTER_CUDNN_OPERATOR(
    ConvTransposeGradient,
    CudnnConvTransposeGradientOp<float>);

} // namespace caffe2
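For reference, here is a minimal sketch of how this cuDNN operator might be invoked from C++ by building an OperatorDef by hand. The blob names ("X", "filter", "bias", "Y") and the kernel/stride/pad/exhaustive_search values are illustrative only, and the input blobs are assumed to already exist in the workspace as CUDA tensors; the sketch relies on the generated caffe2 proto setters and Workspace::RunOperatorOnce.

#include <string>

#include "caffe2/core/workspace.h"
#include "caffe2/proto/caffe2_pb.h"

// Hypothetical helper: runs a ConvTranspose op with engine CUDNN on a
// workspace whose "X", "filter", and "bias" blobs are already CUDA tensors.
void RunConvTransposeExample(caffe2::Workspace* ws) {
  caffe2::OperatorDef def;
  def.set_type("ConvTranspose");
  def.set_engine("CUDNN"); // route to the cuDNN implementation registered above
  def.add_input("X");      // input activations
  def.add_input("filter"); // weights: (M, kH, kW, C) in NHWC, (M, C, kH, kW) in NCHW
  def.add_input("bias");   // optional third input
  def.add_output("Y");
  def.mutable_device_option()->set_device_type(caffe2::PROTO_CUDA);

  // Small helper to attach integer arguments to the OperatorDef.
  auto add_int_arg = [&def](const std::string& name, int value) {
    auto* arg = def.add_arg();
    arg->set_name(name);
    arg->set_i(value);
  };
  add_int_arg("kernel", 4);
  add_int_arg("stride", 2);
  add_int_arg("pad", 1);
  add_int_arg("exhaustive_search", 1); // benchmark cuDNN algorithms once per shape

  CAFFE_ENFORCE(ws->RunOperatorOnce(def));
}

Setting the engine to "CUDNN" is what selects CudnnConvTransposeOp via REGISTER_CUDNN_OPERATOR; without it, the default CUDA implementation of ConvTranspose is used.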