Caffe2 - C++ API
A deep learning, cross platform ML framework
conv_transpose_op_cudnn.cc
#include "caffe2/core/context_gpu.h"
#include "caffe2/core/cudnn_wrappers.h"
#include "caffe2/operators/conv_op_cache_cudnn.h"
#include "caffe2/operators/conv_transpose_op.h"

namespace caffe2 {

// In the early days, Caffe set the default cudnn workspace to 8MB. We bump
// it up to 64MB in Caffe2, as this enables the use of Winograd in many cases,
// something very beneficial to more recent CNN models.
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;

// Manually specified numbers of algorithms implemented in cuDNN.
// These have no performance implications, since we always find the fastest
// algorithm; setting them to the right values simply lets us report complete
// statistics when doing an exhaustive search.
static constexpr size_t kNUM_CUDNN_FWD_ALGS = 7;
static constexpr size_t kNUM_CUDNN_BWD_FILTER_ALGS = 4;
static constexpr size_t kNUM_CUDNN_BWD_DATA_ALGS = 5;

namespace {
template <typename ArrayOfcudnnConvolutionAlgoPerf_t>
inline void LogCuDNNPerfStats(
    const ArrayOfcudnnConvolutionAlgoPerf_t& perf_stat,
    int returned_algo_count) {
  LOG(INFO) << "Perf result: (algo: stat, time, memory)";
  for (int i = 0; i < returned_algo_count; ++i) {
    const auto& stat = perf_stat[i];
    LOG(INFO) << stat.algo << ": " << stat.status << " " << stat.time << " "
              << stat.memory;
  }
}
} // namespace

class CudnnConvTransposeOpBase : public ConvTransposeUnpoolBase<CUDAContext> {
 public:
  CudnnConvTransposeOpBase(const OperatorDef& operator_def, Workspace* ws)
      : ConvTransposeUnpoolBase<CUDAContext>(operator_def, ws),
        cudnn_wrapper_(&context_),
        cudnn_ws_nbytes_limit_(OperatorBase::GetSingleArgument<size_t>(
            "ws_nbytes_limit",
            kCONV_CUDNN_WORKSPACE_LIMIT_BYTES)),
        exhaustive_search_(
            OperatorBase::GetSingleArgument<int>("exhaustive_search", 0)),
        deterministic_(
            OperatorBase::GetSingleArgument<int>("deterministic", 0)),
        cudnn_state_(OperatorBase::GetSingleArgument<int>("cudnn_state", 0)) {
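    // Deterministic mode and exhaustive search are mutually exclusive: an
    // exhaustive benchmark may select a non-deterministic algorithm.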
    CAFFE_ENFORCE(!deterministic_ || !exhaustive_search_);
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bottom_desc_));
    CUDNN_ENFORCE(cudnnCreateFilterDescriptor(&filter_desc_));
    if (InputSize() == 3) {
      CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&bias_desc_));
    }
    CUDNN_ENFORCE(cudnnCreateTensorDescriptor(&top_desc_));
    CUDNN_ENFORCE(cudnnCreateConvolutionDescriptor(&conv_desc_));
  }

  ~CudnnConvTransposeOpBase() {
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bottom_desc_));
    CUDNN_ENFORCE(cudnnDestroyFilterDescriptor(filter_desc_));
    if (InputSize() == 3) {
      CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(bias_desc_));
    }
    CUDNN_ENFORCE(cudnnDestroyTensorDescriptor(top_desc_));
    CUDNN_ENFORCE(cudnnDestroyConvolutionDescriptor(conv_desc_));
  }

 protected:
  vector<TIndex> cudnn_input_dims_;
  vector<TIndex> cudnn_filter_dims_;

  CuDNNWrapper cudnn_wrapper_;
  cudnnTensorDescriptor_t bottom_desc_;
  cudnnFilterDescriptor_t filter_desc_;
  cudnnTensorDescriptor_t bias_desc_;
  cudnnTensorDescriptor_t top_desc_;
  cudnnConvolutionDescriptor_t conv_desc_;
  const size_t cudnn_ws_nbytes_limit_;
  size_t cudnn_ws_nbytes_;
  bool exhaustive_search_;
  bool deterministic_;
  size_t cudnn_state_;
};

template <typename T>
class CudnnConvTransposeOp final : public CudnnConvTransposeOpBase {
 public:
  CudnnConvTransposeOp(const OperatorDef& operator_def, Workspace* ws)
      : CudnnConvTransposeOpBase(operator_def, ws) {}

  ~CudnnConvTransposeOp() {}

  bool RunOnDevice() override;

 private:
  AlgorithmsCache<cudnnConvolutionBwdDataAlgo_t> data_algo_cache_;
  cudnnConvolutionBwdDataAlgo_t bwd_data_algo_;
  // Input: X, W, b
  // Output: Y
  INPUT_TAGS(INPUT, FILTER, BIAS);
};

template <typename T>
class CudnnConvTransposeGradientOp final : public CudnnConvTransposeOpBase {
 public:
  CudnnConvTransposeGradientOp(const OperatorDef& operator_def, Workspace* ws)
      : CudnnConvTransposeOpBase(operator_def, ws),
        no_bias_(OperatorBase::GetSingleArgument<bool>("no_bias", false)) {
    CAFFE_ENFORCE(
        !(no_bias_ && OutputSize() == 3),
        "If bias is not present, you should not have 3 grad outputs.");
  }

  ~CudnnConvTransposeGradientOp() {}

  bool RunOnDevice() override;

 private:
  cudnnConvolutionFwdAlgo_t algo_;
  cudnnConvolutionBwdFilterAlgo_t bwd_filter_algo_;
  AlgorithmsCache<cudnnConvolutionFwdAlgo_t> forward_algo_cache_;
  AlgorithmsCache<cudnnConvolutionBwdFilterAlgo_t> filter_algo_cache_;
  const bool no_bias_;
  // input: X, W, dY
  // output: dW, optionally db and dX
  INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
  OUTPUT_TAGS(FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
};

// Implementations

template <typename T>
bool CudnnConvTransposeOp<T>::RunOnDevice() {
  auto& X = Input(INPUT);
  auto& filter = Input(FILTER);
  auto* Y = Output(0);
  int C = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      C = filter.dim32(3);
      break;
    case StorageOrder::NCHW:
      C = filter.dim32(1);
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }
  ConvTransposeUnpoolBase<CUDAContext>::SetOutputSize(X, Y, C);

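  // For reference, with equal padding on each side the output size that
  // SetOutputSize resolves to follows the standard transposed-convolution
  // relation (assuming no extra output adjustment):
  //   H_out = (H - 1) * stride_h - pad_t - pad_b + kernel_h
  //   W_out = (W - 1) * stride_w - pad_l - pad_r + kernel_w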
  int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      N = X.dim32(0);
      H = X.dim32(1);
      W = X.dim32(2);
      M = X.dim32(3);
      H_out = Y->dim32(1);
      W_out = Y->dim32(2);
      CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
      CAFFE_ENFORCE_EQ(filter.dim32(3), C);
      break;
    case StorageOrder::NCHW:
      N = X.dim32(0);
      M = X.dim32(1);
      H = X.dim32(2);
      W = X.dim32(3);
      H_out = Y->dim32(2);
      W_out = Y->dim32(3);
      CAFFE_ENFORCE_EQ(filter.dim32(1), C);
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }

  if (InputSize() == 3) {
    auto& bias = Input(BIAS);
    CAFFE_ENFORCE_EQ(bias.ndim(), 1);
    CAFFE_ENFORCE_EQ(bias.dim32(0), C);
  }

  // Set up the cudnn algorithms & workspace if necessary.
  bool input_changed = (X.dims() != cudnn_input_dims_);
  bool filter_changed = (filter.dims() != cudnn_filter_dims_);

  if (input_changed || filter_changed) {
    VLOG(1) << "Changing the cudnn descriptor configurations.";
    if (input_changed) {
      cudnn_input_dims_ = X.dims();
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          bottom_desc_,
          GetCudnnTensorFormat(order_),
          cudnnTypeWrapper<T>::type,
          N,
          M,
          H,
          W));
    }
    if (filter_changed) {
      cudnn_filter_dims_ = filter.dims();
      CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
          filter_desc_,
          cudnnTypeWrapper<T>::type,
          GetCudnnTensorFormat(order_),
          M,
          C,
          kernel_h(),
          kernel_w()));
      if (InputSize() == 3) {
        CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
            bias_desc_,
            GetCudnnTensorFormat(order_),
            cudnnTypeWrapper<T>::type,
            1,
            C,
            1,
            1));
      }
    }
    // Set the output.
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        top_desc_,
        GetCudnnTensorFormat(order_),
        cudnnTypeWrapper<T>::type,
        N,
        C,
        H_out,
        W_out));
    // Set the convolution descriptor.
    CAFFE_ENFORCE_EQ(
        pad_t(),
        pad_b(),
        "The current padding scheme leads to unequal padding on the top and "
        "bottom, which is not supported by cudnn.");
    CAFFE_ENFORCE_EQ(
        pad_l(),
        pad_r(),
        "The current padding scheme leads to unequal padding on the left "
        "and right, which is not supported by cudnn.");
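    // cuDNN 6 added a compute-type argument to cudnnSetConvolution2dDescriptor,
    // hence the version guard below.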
#if CUDNN_VERSION_MIN(6, 0, 0)
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION,
        cudnnTypeWrapper<T>::type));
#else
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION));
#endif
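    // Choose the backward-data algorithm one of three ways: a fixed
    // deterministic algorithm, an exhaustive benchmark cached per
    // input/filter shape, or the cuDNN heuristic constrained by the
    // workspace limit.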
    if (deterministic_) {
      bwd_data_algo_ = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
    } else if (exhaustive_search_) {
      bwd_data_algo_ =
          data_algo_cache_.getAlgorithm(X.dims(), filter.dims(), [&]() {
            int returned_algo_count;
            std::array<
                cudnnConvolutionBwdDataAlgoPerf_t,
                kNUM_CUDNN_BWD_DATA_ALGS>
                data_perf_stat;
            cudnn_wrapper_.with_cudnn_state(
                cudnn_state_, [&](CuDNNState* state) {
                  state->workspace().reset();
                  CUDNN_ENFORCE(cudnnFindConvolutionBackwardDataAlgorithm(
                      state->cudnn_handle(),
                      filter_desc_,
                      bottom_desc_,
                      conv_desc_,
                      top_desc_,
                      kNUM_CUDNN_BWD_DATA_ALGS,
                      &returned_algo_count,
                      data_perf_stat.data()));
                });

            LogCuDNNPerfStats(data_perf_stat, returned_algo_count);
            return data_perf_stat[0].algo;
          });
    } else {
      CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataAlgorithm(
          cudnn_wrapper_.inline_cudnn_handle(),
          filter_desc_,
          bottom_desc_,
          conv_desc_,
          top_desc_,
          CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
          cudnn_ws_nbytes_limit_,
          &bwd_data_algo_));
    }

    size_t bwd_data_ws_size;
    CUDNN_ENFORCE(cudnnGetConvolutionBackwardDataWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        filter_desc_,
        bottom_desc_,
        conv_desc_,
        top_desc_,
        bwd_data_algo_,
        &bwd_data_ws_size));
    cudnn_ws_nbytes_ = bwd_data_ws_size;
    VLOG(1) << "CuDNN algorithm: " << bwd_data_algo_;
    VLOG(1) << "CuDNN workspace size: " << bwd_data_ws_size;
  }

  // Now, actually run the computation. The forward pass of a transposed
  // convolution is the backward-data pass of a regular convolution, so we
  // dispatch to cudnnConvolutionBackwardData here.
  cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
    CUDNN_ENFORCE(cudnnConvolutionBackwardData(
        state->cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        filter_desc_,
        filter.template data<T>(),
        bottom_desc_,
        X.template data<T>(),
        conv_desc_,
        bwd_data_algo_,
        state->workspace().get(cudnn_ws_nbytes_),
        cudnn_ws_nbytes_,
        cudnnTypeWrapper<T>::kZero(),
        top_desc_,
        Y->template mutable_data<T>()));
  });
  // Bias
  if (InputSize() == 3) {
    CUDNN_ENFORCE(cudnnAddTensor(
        cudnn_wrapper_.inline_cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        bias_desc_,
        Input(BIAS).template data<T>(),
        cudnnTypeWrapper<T>::kOne(),
        top_desc_,
        Y->template mutable_data<T>()));
  }
  // Done.
  return true;
}

// TODO(Yangqing): a lot of the function contents are very similar. Consider
// consolidating them.
template <typename T>
bool CudnnConvTransposeGradientOp<T>::RunOnDevice() {
  auto& X = Input(INPUT);
  auto& filter = Input(FILTER);
  auto& dY = Input(OUTPUT_GRAD);
  auto* dfilter = Output(FILTER_GRAD);
  CAFFE_ENFORCE_EQ(X.ndim(), 4);
  CAFFE_ENFORCE_EQ(filter.ndim(), 4);
  int C = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      C = filter.dim32(3);
      break;
    case StorageOrder::NCHW:
      C = filter.dim32(1);
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }

  int N = 0, M = 0, H = 0, W = 0, H_out = 0, W_out = 0;
  switch (order_) {
    case StorageOrder::NHWC:
      N = X.dim32(0);
      H = X.dim32(1);
      W = X.dim32(2);
      M = X.dim32(3);
      H_out = dY.dim32(1);
      W_out = dY.dim32(2);
      CAFFE_ENFORCE_EQ(filter.dim32(1), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_w());
      CAFFE_ENFORCE_EQ(filter.dim32(3), C);
      break;
    case StorageOrder::NCHW:
      N = X.dim32(0);
      M = X.dim32(1);
      H = X.dim32(2);
      W = X.dim32(3);
      H_out = dY.dim32(2);
      W_out = dY.dim32(3);
      CAFFE_ENFORCE_EQ(filter.dim32(1), C);
      CAFFE_ENFORCE_EQ(filter.dim32(2), kernel_h());
      CAFFE_ENFORCE_EQ(filter.dim32(3), kernel_w());
      break;
    default:
      LOG(FATAL) << "Unknown storage order: " << order_;
  }
  // Since we only handle LegacyPadding::NOTSET, we don't need to
  // compute padding.
  dfilter->ResizeLike(filter);

  // Set up the cudnn algorithms & workspace if necessary.
  bool input_changed = (X.dims() != cudnn_input_dims_);
  bool filter_changed = (filter.dims() != cudnn_filter_dims_);
  if (input_changed || filter_changed) {
    VLOG(1) << "Changing the cudnn descriptor configurations.";
    if (input_changed) {
      cudnn_input_dims_ = X.dims();
      CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
          bottom_desc_,
          GetCudnnTensorFormat(order_),
          cudnnTypeWrapper<T>::type,
          N,
          M,
          H,
          W));
    }
    if (filter_changed) {
      cudnn_filter_dims_ = filter.dims();
      CUDNN_ENFORCE(cudnnSetFilter4dDescriptor(
          filter_desc_,
          cudnnTypeWrapper<T>::type,
          GetCudnnTensorFormat(order_),
          M,
          C,
          kernel_h(),
          kernel_w()));
      if (!no_bias_) {
        CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
            bias_desc_,
            GetCudnnTensorFormat(order_),
            cudnnTypeWrapper<T>::type,
            1,
            C,
            1,
            1));
      }
    }
    // Set the output.
    CUDNN_ENFORCE(cudnnSetTensor4dDescriptor(
        top_desc_,
        GetCudnnTensorFormat(order_),
        cudnnTypeWrapper<T>::type,
        N,
        C,
        H_out,
        W_out));
    // Set the convolution descriptor.
    CAFFE_ENFORCE_EQ(
        pad_t(),
        pad_b(),
        "The current padding scheme leads to unequal padding on the top and "
        "bottom, which is not supported by cudnn.");
    CAFFE_ENFORCE_EQ(
        pad_l(),
        pad_r(),
        "The current padding scheme leads to unequal padding on the left "
        "and right, which is not supported by cudnn.");
#if CUDNN_VERSION_MIN(6, 0, 0)
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION,
        cudnnTypeWrapper<T>::type));
#else
    CUDNN_ENFORCE(cudnnSetConvolution2dDescriptor(
        conv_desc_,
        pad_t(),
        pad_l(),
        stride_h(),
        stride_w(),
        1,
        1,
        CUDNN_CROSS_CORRELATION));
#endif
    // Set the workspace.

    size_t bwd_filter_ws_size, fwd_ws_size;

    if (deterministic_) {
      algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
      bwd_filter_algo_ = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;
    } else if (exhaustive_search_) {
      bwd_filter_algo_ =
          filter_algo_cache_.getAlgorithm(X.dims(), filter.dims(), [&]() {
            LOG(INFO) << "CUDNN Convolution bwd: doing exhaustive search.";
            // When we do an exhaustive search, we will ignore the workspace
            // size limit and simply go for the fastest algorithm. If you
            // happen to run out of memory later, you will be on your own...
            int returned_algo_count;
            // We clean up the current workspace memory so that the search is
            // free to allocate memory, then actually run the search.
            std::array<
                cudnnConvolutionBwdFilterAlgoPerf_t,
                kNUM_CUDNN_BWD_FILTER_ALGS>
                filter_perf_stat;

            cudnn_wrapper_.with_cudnn_state(
                cudnn_state_, [&](CuDNNState* state) {
                  state->workspace().reset();
                  CUDNN_ENFORCE(cudnnFindConvolutionBackwardFilterAlgorithm(
                      state->cudnn_handle(),
                      top_desc_,
                      bottom_desc_,
                      conv_desc_,
                      filter_desc_,
                      kNUM_CUDNN_BWD_FILTER_ALGS,
                      &returned_algo_count,
                      filter_perf_stat.data()));
                });
            LogCuDNNPerfStats(filter_perf_stat, returned_algo_count);
            return filter_perf_stat[0].algo;
          });

      algo_ = forward_algo_cache_.getAlgorithm(X.dims(), filter.dims(), [&]() {
        int returned_algo_count;
        std::array<cudnnConvolutionFwdAlgoPerf_t, kNUM_CUDNN_FWD_ALGS>
            fwd_perf_stat;
        cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
          state->workspace().reset();
          CUDNN_ENFORCE(cudnnFindConvolutionForwardAlgorithm(
              state->cudnn_handle(),
              top_desc_,
              filter_desc_,
              conv_desc_,
              bottom_desc_,
              kNUM_CUDNN_FWD_ALGS,
              &returned_algo_count,
              fwd_perf_stat.data()));
        });

        LogCuDNNPerfStats(fwd_perf_stat, returned_algo_count);
        return fwd_perf_stat[0].algo;
      });
    } else {
      // Choose the backward-filter algorithm.
      CUDNN_ENFORCE(cudnnGetConvolutionBackwardFilterAlgorithm(
          cudnn_wrapper_.inline_cudnn_handle(),
          top_desc_,
          bottom_desc_,
          conv_desc_,
          filter_desc_,
          CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
          cudnn_ws_nbytes_limit_,
          &bwd_filter_algo_));
      // Choose the forward algorithm, which computes the input gradient here.
      CUDNN_ENFORCE(cudnnGetConvolutionForwardAlgorithm(
          cudnn_wrapper_.inline_cudnn_handle(),
          top_desc_,
          filter_desc_,
          conv_desc_,
          bottom_desc_,
          CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
          cudnn_ws_nbytes_limit_,
          &algo_));
    }
    // Get the workspace size for the backward-filter algorithm.
    CUDNN_ENFORCE(cudnnGetConvolutionBackwardFilterWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        top_desc_,
        bottom_desc_,
        conv_desc_,
        filter_desc_,
        bwd_filter_algo_,
        &bwd_filter_ws_size));
    // Get the workspace size for the forward algorithm.
    CUDNN_ENFORCE(cudnnGetConvolutionForwardWorkspaceSize(
        cudnn_wrapper_.inline_cudnn_handle(),
        top_desc_,
        filter_desc_,
        conv_desc_,
        bottom_desc_,
        algo_,
        &fwd_ws_size));
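    // A single workspace buffer is shared by both passes, so size it for the
    // larger of the two requirements.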
    cudnn_ws_nbytes_ = std::max(bwd_filter_ws_size, fwd_ws_size);

    VLOG(1) << "CuDNN algorithms: " << bwd_filter_algo_ << ", " << algo_;
    VLOG(1) << "CuDNN workspace size: " << cudnn_ws_nbytes_;
  }

  // Now, actually run the computation.
  if (!no_bias_) {
    auto* dbias = Output(BIAS_OR_INPUT_GRAD);
    dbias->Resize(C);
    CUDNN_ENFORCE(cudnnConvolutionBackwardBias(
        cudnn_wrapper_.inline_cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        top_desc_,
        dY.template data<T>(),
        cudnnTypeWrapper<T>::kZero(),
        bias_desc_,
        dbias->template mutable_data<T>()));
  }

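  // Because ConvTranspose's forward pass is convolution's backward-data pass,
  // the roles of X and dY are swapped relative to a plain convolution below:
  // dY acts as the convolution "input" and X as its "output gradient".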
  cudnn_wrapper_.with_cudnn_state(cudnn_state_, [&](CuDNNState* state) {
    CUDNN_ENFORCE(cudnnConvolutionBackwardFilter(
        state->cudnn_handle(),
        cudnnTypeWrapper<T>::kOne(),
        top_desc_,
        dY.template data<T>(),
        bottom_desc_,
        X.template data<T>(),
        conv_desc_,
        bwd_filter_algo_,
        state->workspace().get(cudnn_ws_nbytes_),
        cudnn_ws_nbytes_,
        cudnnTypeWrapper<T>::kZero(),
        filter_desc_,
        dfilter->template mutable_data<T>()));

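    // dX is requested when there are three outputs (dW, db, dX), or two
    // outputs (dW, dX) when no_bias is set.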
    if (OutputSize() == 3 || (no_bias_ && (OutputSize() == 2))) {
      // Compute the gradient w.r.t. the input.
      auto* dX = Output(no_bias_ ? BIAS_OR_INPUT_GRAD : INPUT_GRAD);
      dX->ResizeLike(X);
      CUDNN_ENFORCE(cudnnConvolutionForward(
          state->cudnn_handle(),
          cudnnTypeWrapper<T>::kOne(),
          top_desc_,
          dY.template data<T>(),
          filter_desc_,
          filter.template data<T>(),
          conv_desc_,
          algo_,
          state->workspace().get(cudnn_ws_nbytes_),
          cudnn_ws_nbytes_,
          cudnnTypeWrapper<T>::kZero(),
          bottom_desc_,
          dX->template mutable_data<T>()));
    }
  });
  return true;
}

REGISTER_CUDNN_OPERATOR(ConvTranspose, CudnnConvTransposeOp<float>);
REGISTER_CUDNN_OPERATOR(
    ConvTransposeGradient,
    CudnnConvTransposeGradientOp<float>);

} // namespace caffe2
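
Once registered, this implementation is selected by requesting the CUDNN engine on a ConvTranspose operator running on a CUDA device. A minimal usage sketch from C++ follows (assuming the Caffe2-era OperatorDef proto API; the blob names X, W, b, Y and the kernel/stride/pad values are illustrative, and the input blobs must be populated with CUDA tensors before Run() is called):

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

void RunConvTransposeViaCudnn() {
  caffe2::OperatorDef def;
  def.set_type("ConvTranspose");
  def.set_engine("CUDNN");  // dispatch to the cuDNN implementation above
  def.mutable_device_option()->set_device_type(caffe2::CUDA);
  def.add_input("X");  // input activations (hypothetical blob name)
  def.add_input("W");  // filters, M x C x kernel_h x kernel_w
  def.add_input("b");  // optional bias of size C
  def.add_output("Y");
  auto* kernel = def.add_arg();
  kernel->set_name("kernel");
  kernel->set_i(4);
  auto* stride = def.add_arg();
  stride->set_name("stride");
  stride->set_i(2);
  auto* pad = def.add_arg();
  pad->set_name("pad");
  pad->set_i(1);

  caffe2::Workspace ws;
  // ... feed the X, W, b blobs with CUDA tensors here before running ...
  auto op = caffe2::CreateOperator(def, &ws);
  CAFFE_ENFORCE(op->Run());
}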