#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/Config.h>
#include <ATen/cuda/CUDAConfig.h>
#include <ATen/cuda/Exceptions.h>

#if !AT_CUDNN_ENABLED()

namespace at { namespace native {
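
// ATen is built without cuDNN in this configuration: define stubs for the
// cudnn_convolution entry points so the symbols still exist, and fail
// loudly at runtime if any of them is actually called.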
at::Tensor cudnn_convolution(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution: ATen not compiled with cuDNN support");
}
at::Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_backward_input: ATen not compiled with cuDNN support");
}
at::Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_backward_weight: ATen not compiled with cuDNN support");
}
at::Tensor cudnn_convolution_backward_bias(
    const at::Tensor& grad_output) {
  AT_ERROR("cudnn_convolution_backward_bias: ATen not compiled with cuDNN support");
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  AT_ERROR("cudnn_convolution_backward: ATen not compiled with cuDNN support");
}
at::Tensor cudnn_convolution_transpose(
    const at::Tensor& input, const at::Tensor& weight, const at::Tensor& bias /* optional */,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_transpose: ATen not compiled with cuDNN support");
}
at::Tensor cudnn_convolution_transpose_backward_input(
    const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support");
}
at::Tensor cudnn_convolution_transpose_backward_weight(
    IntArrayRef weight_size, const at::Tensor& grad_output, const at::Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {
  AT_ERROR("cudnn_convolution_transpose_backward_weight: ATen not compiled with cuDNN support");
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {
  AT_ERROR("cudnn_convolution_transpose_backward: ATen not compiled with cuDNN support");
}

}}  // namespace at::native
#else  // AT_CUDNN_ENABLED

#include <ATen/cudnn/cudnn-wrapper.h>
#include <ATen/cudnn/Descriptors.h>
#include <ATen/cudnn/Types.h>
#include <ATen/cudnn/Utils.h>
#include <ATen/native/utils/ParamsHash.h>

#include <ATen/TensorUtils.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <functional>
#include <iterator>
#include <sstream>
#include <algorithm>
#include <memory>
#include <mutex>
#include <cstring>
#include <type_traits>
#include <stdint.h>
#include <unordered_map>

namespace at { namespace native {
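
// Tensors are passed to cuDNN in its default NCHW-style layout: dim 0 is
// batch (or output channels, for weight tensors) and dim 1 is channels.
// max_dim is the maximum number of spatial dimensions handled (3, i.e. up
// to volumetric convolution), so full shapes have at most 2 + max_dim dims.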
constexpr int input_batch_size_dim = 0;
constexpr int input_channels_dim = 1;
constexpr int output_batch_size_dim = 0;
constexpr int output_channels_dim = 1;
constexpr int weight_output_channels_dim = 0;
constexpr int weight_input_channels_dim = 1;

constexpr int max_dim = 3;
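
// Forward output shape: for each spatial dim d the effective (dilated)
// kernel extent is dilation * (kernel - 1) + 1, giving
//   out[d] = (in[d] + 2 * padding[d] - dilation[d] * (kernel[d] - 1) - 1) / stride[d] + 1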
static std::vector<int64_t> conv_output_size(
    IntArrayRef input_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<int64_t> output_size(dim);
  output_size[0] = input_size[input_batch_size_dim];
  output_size[1] = weight_size[weight_output_channels_dim];
  for (size_t d = 2; d < dim; ++d) {
    auto kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
    output_size[d] = (input_size[d] + (2 * padding[d - 2])
                        - kernel) / stride[d - 2] + 1;
  }
  return output_size;
}
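
// Inverse of conv_output_size, used for transposed convolution / backward
// data: because the forward formula rounds down under integer division,
// several input sizes can map to the same output size, and output_padding
// selects among them.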
std::vector<int64_t> conv_input_size(
    IntArrayRef output_size, IntArrayRef weight_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = output_size.size();
  std::vector<int64_t> input_size(dim);
  input_size[0] = output_size[output_batch_size_dim];
  input_size[1] = weight_size[weight_input_channels_dim] * groups;
  for (size_t d = 2; d < dim; ++d) {
    int kernel = dilation[d - 2] * (weight_size[d] - 1) + 1;
    input_size[d] = (output_size[d] - 1) * stride[d - 2] - (2 * padding[d - 2]) +
                     kernel + output_padding[d - 2];
  }
  return input_size;
}
std::vector<int64_t> conv_weight_size(
    IntArrayRef input_size, IntArrayRef output_size,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups
) {
  auto dim = input_size.size();
  std::vector<int64_t> weight_size(dim);
  weight_size[0] = output_size[1];
  weight_size[1] = input_size[1] / groups;
  for (size_t d = 2; d < dim; ++d) {
    int kernel = input_size[d] - (output_size[d] - 1) * stride[d - 2]
               + 2 * padding[d - 2] - output_padding[d - 2];
    weight_size[d] = (kernel - 1) / dilation[d - 2] + 1;
  }
  return weight_size;
}
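
// Slices out the group_idx-th group along `dim` for grouped convolutions;
// assumes t.size(dim) is divisible by groups.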
Tensor narrowGroup(const Tensor& t, int dim, int group_idx, int64_t groups) {
  auto group_size = t.size(dim) / groups;
  return t.narrow(dim, group_idx * group_size, group_size);
}
static void check_args(CheckedFrom c, IntArrayRef args, size_t expected_size, const char* arg_name)
{
  AT_CHECK(args.size() <= expected_size,
           "Too many ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");
  AT_CHECK(args.size() >= expected_size,
           "Not enough ", arg_name, " values (", args.size(), ") supplied, expecting ",
           expected_size, " (while checking arguments for ", c, ")");

  auto num_negative_values = std::count_if(args.begin(), args.end(), [](int x){ return x < 0; });
  if (num_negative_values > 0){
    std::stringstream ss;
    ss << arg_name << " should be greater than zero but got (";
    std::copy(args.begin(), args.end() - 1, std::ostream_iterator<int>(ss, ", "));
    ss << args.back() << ")" << " (while checking arguments for " << c << ")";
    AT_ERROR(ss.str());
  }
}
static void convolution_shape_check(
    CheckedFrom c,
    const TensorGeometryArg& input, const TensorGeometryArg& weight, const TensorGeometryArg& output,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups)
{
  check_args(c, padding, input->dim() - 2, "padding");
  check_args(c, stride, padding.size(), "stride");
  check_args(c, dilation, padding.size(), "dilation");

  checkDimRange(c, input, 3, 6 /* exclusive */);
  checkSize(c, input, input_channels_dim, weight->size(1) * groups);

  checkSameDim(c, input, weight);
  checkSameDim(c, input, output);
}
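
// ConvolutionParams is the key into the benchmark caches below. It is
// hashed and compared bytewise (ParamsHash / ParamsEqual), which is why
// setConvolutionParams zeroes the whole struct first: any padding bytes
// must be in a deterministic state.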
struct ConvolutionParams
{
  cudnnDataType_t dataType;
  int input_size[2 + max_dim];
  int input_stride[2 + max_dim];
  int weight_size[2 + max_dim];
  int padding[max_dim];
  int stride[max_dim];
  int dilation[max_dim];
  int64_t groups;
  bool deterministic;
};
void setConvolutionParams(
    ConvolutionParams* params,
    const at::Tensor& input, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool deterministic) {

  cudnnDataType_t dataType = getCudnnDataType(input);
  memset(params, 0, sizeof(ConvolutionParams));
  params->dataType = dataType;
  for (int i = 0; i != input.dim(); ++i) {
    params->input_size[i] = (int) input.size(i);
    params->input_stride[i] = (int) input.stride(i);
    params->weight_size[i] = (int) weight.size(i);
  }
  for (size_t i = 0; i != padding.size(); ++i) {
    params->padding[i] = padding[i];
    params->stride[i] = stride[i];
    params->dilation[i] = dilation[i];
  }
  params->groups = groups;
  params->deterministic = deterministic;
}
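
// Convenience bundle of everything needed to issue a cuDNN convolution
// call: the handle, the cache key, the tensor/filter/convolution
// descriptors, and references to the tensors themselves.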
struct ConvolutionArgs {
  cudnnHandle_t handle;
  ConvolutionParams params;
  TensorDescriptor idesc, odesc;
  FilterDescriptor wdesc;
  const Tensor& input, output, weight;
  ConvolutionDescriptor cdesc;

  ConvolutionArgs(const Tensor& input, const Tensor& output, const Tensor& weight)
    : input(input), output(output), weight(weight) {}
};
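
// Thread-safe memoization of algorithm choices, keyed on ConvolutionParams,
// so the (potentially expensive) algorithm search runs at most once per
// distinct configuration.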
template <typename T>
struct BenchmarkCache {
  std::mutex mutex;
  std::unordered_map<ConvolutionParams, T, ParamsHash<ConvolutionParams>, ParamsEqual<ConvolutionParams>> map;

  bool find(const ConvolutionParams& params, T* results) {
    std::lock_guard<std::mutex> guard(mutex);
    auto it = map.find(params);
    if (it == map.end()) {
      return false;
    }
    *results = it->second;
    return true;
  }

  void insert(const ConvolutionParams& params, const T& results) {
    std::lock_guard<std::mutex> guard(mutex);
    map[params] = results;
  }
};
BenchmarkCache<cudnnConvolutionFwdAlgoPerf_t> fwd_algos;
BenchmarkCache<cudnnConvolutionBwdDataAlgoPerf_t> bwd_data_algos;
BenchmarkCache<cudnnConvolutionBwdFilterAlgoPerf_t> bwd_filter_algos;
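
// Move-only RAII wrapper for the scratch buffer cuDNN algorithms may
// require; the allocation is freed exactly once on destruction.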
struct Workspace {
  Workspace(size_t size) : size(size), data(NULL) {
    data = THCudaMalloc(globalContext().lazyInitCUDA(), size);
  }
  Workspace(const Workspace&) = delete;
  Workspace(Workspace&&) = default;
  Workspace& operator=(Workspace&&) = default;
  ~Workspace() {
    if (data) {
      THCudaFree(globalContext().lazyInitCUDA(), data);
    }
  }

  size_t size;
  void* data;
};
template<typename perf_t>
struct algorithm_search {};

cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionFwdAlgo_t algo, size_t* sz)
{
  return cudnnGetConvolutionForwardWorkspaceSize(
      args.handle,
      args.idesc.desc(),
      args.wdesc.desc(),
      args.cdesc.desc(),
      args.odesc.desc(),
      algo,
      sz);
}
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdDataAlgo_t algo, size_t* sz)
{
  return cudnnGetConvolutionBackwardDataWorkspaceSize(
      args.handle,
      args.wdesc.desc(),
      args.odesc.desc(),
      args.cdesc.desc(),
      args.idesc.desc(),
      algo,
      sz);
}
cudnnStatus_t getWorkspaceSize(
    const ConvolutionArgs& args,
    cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz)
{
  return cudnnGetConvolutionBackwardFilterWorkspaceSize(
      args.handle,
      args.idesc.desc(),
      args.odesc.desc(),
      args.cdesc.desc(),
      args.wdesc.desc(),
      algo,
      sz);
}
template<typename algo_t>
size_t getMaxWorkspaceSize(
    const ConvolutionArgs& args,
    const algo_t *algo, int n_algo)
{
  THCState *state = globalContext().lazyInitCUDA();

  size_t max_ws_size = 0;
  size_t max_block_size = 0;
  size_t total_gpu_mem = 0;
  size_t free_gpu_mem = 0;

  THCudaCheck(THCudaMemGetInfo(state, &free_gpu_mem, &total_gpu_mem, &max_block_size));

  for (int i = 0; i < n_algo; i++) {
    cudnnStatus_t err;
    size_t sz;
    err = getWorkspaceSize(args, algo[i], &sz);
    if (CUDNN_STATUS_SUCCESS != err || sz == 0
        || sz < max_ws_size || sz > max_block_size) continue;
    max_ws_size = sz;
  }
  return max_ws_size;
}
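
// Picks the perf result to use: the first successful one or, in
// deterministic mode, the first successful result flagged
// CUDNN_DETERMINISTIC. Before cuDNN 7.5 the FFT-based backward-data
// algorithms are additionally blacklisted for strided convolutions and
// replaced with the default algorithm.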
template<typename perf_t>
perf_t getBestAlgorithm(perf_t *perfResults, const ConvolutionArgs& args, int n_algo) {
  int best_algo_idx = 0;
  bool is_deterministic = false;
  if (args.params.deterministic) {
    for (int i = 0; i < n_algo; i++) {
      if (perfResults[i].status == CUDNN_STATUS_SUCCESS &&
          perfResults[i].determinism == CUDNN_DETERMINISTIC) {
        best_algo_idx = i;
        is_deterministic = true;
        break;
      }
    }
    if (!is_deterministic) {
      AT_ERROR("no deterministic convolution algorithms available in CuDNN");
    }
  }

#if CUDNN_VERSION < 7500
  if (std::is_same<decltype(perfResults[best_algo_idx].algo), cudnnConvolutionBwdDataAlgo_t>::value) {
    int stride_dim = args.input.dim() - 2;
    bool blacklist = std::any_of(std::begin(args.params.stride),
                                 std::begin(args.params.stride) + stride_dim,
                                 [=](int n){ return n != 1; });
    if (blacklist && (static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[best_algo_idx].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING
                      || static_cast<cudnnConvolutionBwdDataAlgo_t>(perfResults[best_algo_idx].algo) == CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT)) {
      perfResults[best_algo_idx].algo = algorithm_search<perf_t>::DEFAULT_ALGO;
      if (args.params.dataType == CUDNN_DATA_HALF) {
        perfResults[best_algo_idx].mathType = CUDNN_TENSOR_OP_MATH;
      } else {
        perfResults[best_algo_idx].mathType = CUDNN_DEFAULT_MATH;
      }
    }
  }
#endif

  return perfResults[best_algo_idx];
}
template<>
struct algorithm_search<cudnnConvolutionFwdAlgoPerf_t> {
  using perf_t = cudnnConvolutionFwdAlgoPerf_t;
  using algo_t = cudnnConvolutionFwdAlgo_t;

  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
  static BenchmarkCache<perf_t>& cache() { return fwd_algos; }

  static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
    static const algo_t algos[] = {
         CUDNN_CONVOLUTION_FWD_ALGO_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_FFT,
         CUDNN_CONVOLUTION_FWD_ALGO_FFT_TILING,
         CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM,
         CUDNN_CONVOLUTION_FWD_ALGO_DIRECT,
         CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
         CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED,
    };
    static constexpr int num_algos = CUDNN_CONVOLUTION_FWD_ALGO_COUNT;
    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
                  "Missing cuDNN convolution forward algorithms");
    int perf_count;
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    if (!benchmark) {
      AT_CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm_v7(
          args.handle,
          args.idesc.desc(),
          args.wdesc.desc(),
          args.cdesc.desc(),
          args.odesc.desc(),
          num_algos,
          &perf_count,
          perf_results.get()));
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      AT_CUDNN_CHECK(cudnnFindConvolutionForwardAlgorithmEx(
          args.handle,
          args.idesc.desc(), args.input.data_ptr(),
          args.wdesc.desc(), args.weight.data_ptr(),
          args.cdesc.desc(),
          args.odesc.desc(), args.output.data_ptr(),
          num_algos,
          &perf_count,
          perf_results.get(),
          ws.data,
          ws.size));
    }
    return getBestAlgorithm<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      algo_t algo, size_t* workspaceSize)
  {
    AT_CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
        args.handle,
        args.idesc.desc(),
        args.wdesc.desc(),
        args.cdesc.desc(),
        args.odesc.desc(),
        algo,
        workspaceSize));
  }
};
template<>
struct algorithm_search<cudnnConvolutionBwdDataAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdDataAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdDataAlgo_t;

  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_DATA_ALGO_1;
  static BenchmarkCache<perf_t>& cache() { return bwd_data_algos; }

  static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_0,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_1,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_FFT_TILING,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD,
        CUDNN_CONVOLUTION_BWD_DATA_ALGO_WINOGRAD_NONFUSED
    };
    static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT;
    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
                  "Missing cuDNN convolution backward data algorithms.");
    int perf_count;
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    if (!benchmark) {
      AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm_v7(
          args.handle,
          args.wdesc.desc(),
          args.odesc.desc(),
          args.cdesc.desc(),
          args.idesc.desc(),
          num_algos,
          &perf_count,
          perf_results.get()));
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      AT_CUDNN_CHECK(cudnnFindConvolutionBackwardDataAlgorithmEx(
          args.handle,
          args.wdesc.desc(), args.weight.data_ptr(),
          args.odesc.desc(), args.output.data_ptr(),
          args.cdesc.desc(),
          args.idesc.desc(), args.input.data_ptr(),
          num_algos,
          &perf_count,
          perf_results.get(),
          ws.data,
          ws.size));
    }
    return getBestAlgorithm<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args,
      cudnnConvolutionBwdDataAlgo_t algo, size_t* workspaceSize)
  {
    AT_CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
        args.handle,
        args.wdesc.desc(),
        args.odesc.desc(),
        args.cdesc.desc(),
        args.idesc.desc(),
        algo,
        workspaceSize));
  }
};
template<>
struct algorithm_search<cudnnConvolutionBwdFilterAlgoPerf_t> {
  using perf_t = cudnnConvolutionBwdFilterAlgoPerf_t;
  using algo_t = cudnnConvolutionBwdFilterAlgo_t;

  static constexpr auto DEFAULT_ALGO = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1;

  static BenchmarkCache<perf_t>& cache() { return bwd_filter_algos; }

  static perf_t findAlgorithm(const ConvolutionArgs& args, bool benchmark) {
    static const algo_t algos[] = {
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_1,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_3,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_WINOGRAD_NONFUSED,
        CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFT_TILING,
    };
    // NOTE: - 1 because ALGO_WINOGRAD is not implemented
    static constexpr int num_algos = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT - 1;
    static_assert(sizeof(algos) / sizeof(algos[0]) == num_algos,
                  "Missing cuDNN convolution backward filter algorithms.");
    std::unique_ptr<perf_t[]> perf_results(new perf_t[num_algos]);
    int perf_count;
    if (!benchmark) {
      AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm_v7(
          args.handle,
          args.idesc.desc(),
          args.odesc.desc(),
          args.cdesc.desc(),
          args.wdesc.desc(),
          num_algos,
          &perf_count,
          perf_results.get()));
    } else {
      size_t max_ws_size = getMaxWorkspaceSize(args, algos, num_algos);
      Workspace ws(max_ws_size);
      AT_CUDNN_CHECK(cudnnFindConvolutionBackwardFilterAlgorithmEx(
          args.handle,
          args.idesc.desc(), args.input.data_ptr(),
          args.odesc.desc(), args.output.data_ptr(),
          args.cdesc.desc(),
          args.wdesc.desc(), args.weight.data_ptr(),
          num_algos,
          &perf_count,
          perf_results.get(),
          ws.data,
          ws.size));
    }
    return getBestAlgorithm<perf_t>(perf_results.get(), args, perf_count);
  }

  static void getWorkspaceSize(
      const ConvolutionArgs& args, algo_t algo,
      size_t* workspaceSize)
  {
    AT_CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
        args.handle,
        args.idesc.desc(),
        args.odesc.desc(),
        args.cdesc.desc(),
        args.wdesc.desc(),
        algo,
        workspaceSize));
  }
};
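
// Algorithm selection entry point: consult the cache first; in
// deterministic non-benchmark mode fall back to the default algorithm
// directly; otherwise run the heuristic (cudnnGet*_v7) or exhaustive
// (cudnnFind*Ex) search and cache benchmark results for reuse.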
template<typename perf_t>
void findAlgorithm(const ConvolutionArgs& args, bool benchmark, perf_t* algoPerf) {
  using search = algorithm_search<perf_t>;
  auto& cache = search::cache();

  if (cache.find(args.params, algoPerf)) {
    return;
  }

  if (args.params.deterministic && !benchmark) {
    algoPerf->algo = search::DEFAULT_ALGO;
    if (args.params.dataType == CUDNN_DATA_HALF) {
      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
    } else {
      algoPerf->mathType = CUDNN_DEFAULT_MATH;
    }
    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
    return;
  }

  if (benchmark) {
    // re-check the cache in case another thread benchmarked this
    // configuration in the meantime
    if (cache.find(args.params, algoPerf)) {
      return;
    }
  }

  auto perfResults = search::findAlgorithm(args, benchmark);
  if (perfResults.status == CUDNN_STATUS_SUCCESS &&
      !(args.params.deterministic && perfResults.determinism != CUDNN_DETERMINISTIC)) {
    if (benchmark) {
      cache.insert(args.params, perfResults);

      // benchmarking can allocate several GB of scratch memory; release
      // cached blocks back to the driver
      c10::cuda::CUDACachingAllocator::emptyCache();
    }
    *algoPerf = perfResults;
  } else {
    algoPerf->algo = search::DEFAULT_ALGO;
    if (args.params.dataType == CUDNN_DATA_HALF) {
      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
    } else {
      algoPerf->mathType = CUDNN_DEFAULT_MATH;
    }
    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
  }
}
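
// Picks an algorithm and allocates its workspace. If the workspace
// allocation throws (typically out of memory), fall back to the default
// algorithm, cache that choice so subsequent calls skip the oversized
// candidate, and retry the allocation.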
template<typename perf_t>
Workspace chooseAlgorithm(
    const ConvolutionArgs& args,
    bool benchmark,
    perf_t* algoPerf)
{
  findAlgorithm(args, benchmark, algoPerf);

  using search = algorithm_search<perf_t>;
  try {
    return Workspace(algoPerf->memory);
  } catch (const std::exception& e) {
    algoPerf->algo = search::DEFAULT_ALGO;
    if (args.params.dataType == CUDNN_DATA_HALF) {
      algoPerf->mathType = CUDNN_TENSOR_OP_MATH;
    } else {
      algoPerf->mathType = CUDNN_DEFAULT_MATH;
    }
    search::getWorkspaceSize(args, algoPerf->algo, &(algoPerf->memory));
    search::cache().insert(args.params, *algoPerf);
    return Workspace(algoPerf->memory);
  }
}
void cudnn_convolution_add_bias_(CheckedFrom c, const TensorArg& output, const TensorArg& bias)
{
  checkAllSameType(c, {output, bias});
  checkAllSameGPU(c, {output, bias});
  checkSize(c, bias, { output->size(output_channels_dim) });

  TensorDescriptor bdesc, odesc;
  bdesc.set(bias->expand({1, bias->size(0)}), output->dim());
  odesc.set(*output);

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*bias);
  Constant one(dataType, 1);

  AT_CUDNN_CHECK(cudnnAddTensor(handle, &one, bdesc.desc(), bias->data_ptr(),
                                &one, odesc.desc(), output->data_ptr()));
}
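
// The raw_cudnn_convolution_*_out functions call straight into cuDNN.
// They do no shape checking and no resizing: callers must supply
// correctly-sized, appropriately-contiguous tensors.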
void raw_cudnn_convolution_forward_out(
    const Tensor& output, const Tensor& input, const Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{ input, output, weight };
  args.handle = getCudnnHandle();
  setConvolutionParams(&args.params, input, weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(input);
  args.wdesc.set(weight);
  args.odesc.set(output);
  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);

  cudnnConvolutionFwdAlgoPerf_t fwdAlgPerf;
  Workspace workspace = chooseAlgorithm(args, benchmark, &fwdAlgPerf);

  // cuDNN 7.4+ consults both the algorithm and the math type recorded on
  // the convolution descriptor to decide whether Tensor Core kernels run
  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), fwdAlgPerf.mathType));

  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  AT_CUDNN_CHECK(cudnnConvolutionForward(
    args.handle,
    &one, args.idesc.desc(), input.data_ptr(),
    args.wdesc.desc(), weight.data_ptr(),
    args.cdesc.desc(), fwdAlgPerf.algo, workspace.data, workspace.size,
    &zero, args.odesc.desc(), output.data_ptr()));
}
Tensor cudnn_convolution_forward(
    CheckedFrom c,
    const TensorArg& input, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {input, weight});
  checkAllSameGPU(c, {input, weight});

  auto output_t = at::empty(
                    conv_output_size(input->sizes(), weight->sizes(),
                                     padding, stride, dilation, groups),
                    input->options());

  // Avoid ambiguity of "output" when this is being used as backwards
  TensorArg output{ output_t, "result", 0 };
  convolution_shape_check(c, input, weight, output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous();

  raw_cudnn_convolution_forward_out(
      *output, *input, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return *output;
}
Tensor cudnn_convolution(
    const Tensor& input_t, const Tensor& weight_t, const Tensor& bias_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  TensorArg input  { input_t,  "input",  1 },
            weight { weight_t, "weight", 2 },
            bias   { bias_t,   "bias",   3 };
  setCuDNNStreamToCurrent();
  CheckedFrom c = "cudnn_convolution";
  auto output_t = cudnn_convolution_forward(
    c, input, weight, padding, stride, dilation, groups, benchmark, deterministic);
  if (bias->defined()) {
    cudnn_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
  }
  return output_t;
}
Tensor cudnn_convolution_transpose_backward_input(
    const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  TensorArg grad_output { grad_output_t, "grad_output", 1 },
            weight      { weight_t,      "weight",      2 };
  setCuDNNStreamToCurrent();
  return cudnn_convolution_forward(
    "cudnn_convolution_transpose_backward_input",
    grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
}
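
// output_mask selects which gradients to compute, in the order
// {input, weight, bias}; unrequested gradients are returned as undefined
// tensors.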
std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_transpose_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

  Tensor grad_output = grad_output_t.contiguous();

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = at::cudnn_convolution_transpose_backward_input(grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[1]) {
    grad_weight = at::cudnn_convolution_transpose_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[2]) {
    grad_bias = at::cudnn_convolution_backward_bias(grad_output);
  }

  return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
}
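
// Backward data is the forward convolution run "in reverse" over the same
// geometry, so ConvolutionArgs still treats the conv-input side
// (grad_input) as "input" and the conv-output side (grad_output) as
// "output".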
void raw_cudnn_convolution_backward_input_out(
    const at::Tensor& grad_input,
    const at::Tensor& grad_output,
    const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getCudnnDataType(grad_output);

  ConvolutionArgs args{ grad_input, grad_output, weight };
  args.handle = getCudnnHandle();
  setConvolutionParams(&args.params, grad_input, weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(grad_input);
  args.wdesc.set(weight);
  args.odesc.set(grad_output);
  args.cdesc.set(dataType, grad_output.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);

  cudnnConvolutionBwdDataAlgoPerf_t bwdDataAlgPerf;
  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdDataAlgPerf);

  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdDataAlgPerf.mathType));

  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  AT_CUDNN_CHECK(cudnnConvolutionBackwardData(
      args.handle,
      &one, args.wdesc.desc(), weight.data_ptr(),
      args.odesc.desc(), grad_output.data_ptr(),
      args.cdesc.desc(), bwdDataAlgPerf.algo, workspace.data, workspace.size,
      &zero, args.idesc.desc(), grad_input.data_ptr()));
}
Tensor cudnn_convolution_backward_input(
    CheckedFrom c,
    IntArrayRef input_size, const TensorArg& grad_output, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {grad_output, weight});
  checkAllSameGPU(c, {grad_output, weight});

  auto grad_input_t = at::empty(input_size, grad_output->options());

  // Avoid "grad_input" when this is being used as transposed convolution
  TensorArg grad_input{ grad_input_t, "result", 0 };
  convolution_shape_check(c, grad_input, weight, grad_output, padding, stride, dilation, groups);

  Tensor weight_contig = weight->contiguous();

  raw_cudnn_convolution_backward_input_out(
      *grad_input, *grad_output, weight_contig,
      padding, stride, dilation, groups, benchmark, deterministic);

  return *grad_input;
}
Tensor cudnn_convolution_transpose_forward(
    CheckedFrom c,
    const TensorArg& grad_output, const TensorArg& weight,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  auto input_size = conv_input_size(grad_output->sizes(), weight->sizes(),
                                    padding, output_padding, stride, dilation, groups);
  return cudnn_convolution_backward_input(c, input_size, grad_output, weight,
                                          padding, stride, dilation, groups, benchmark, deterministic);
}
Tensor cudnn_convolution_backward_input(
    IntArrayRef input_size, const Tensor& grad_output_t, const Tensor& weight_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            weight{ weight_t, "weight", 2 };
  setCuDNNStreamToCurrent();
  return cudnn_convolution_backward_input(
      "cudnn_convolution_backward_input",
      input_size, grad_output, weight,
      padding, stride, dilation, groups, benchmark, deterministic);
}
std::tuple<at::Tensor,at::Tensor,at::Tensor> cudnn_convolution_backward(
    const at::Tensor& input, const at::Tensor& grad_output_t, const at::Tensor& weight,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic, std::array<bool,3> output_mask) {

  Tensor grad_output = grad_output_t.contiguous();

  Tensor grad_input, grad_weight, grad_bias;
  if (output_mask[0]) {
    grad_input = at::cudnn_convolution_backward_input(input.sizes(), grad_output, weight, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[1]) {
    grad_weight = at::cudnn_convolution_backward_weight(weight.sizes(), grad_output, input, padding, stride, dilation, groups, benchmark, deterministic);
  }
  if (output_mask[2]) {
    grad_bias = at::cudnn_convolution_backward_bias(grad_output);
  }

  return std::tuple<Tensor,Tensor,Tensor>{grad_input, grad_weight, grad_bias};
}
Tensor cudnn_convolution_transpose(
    const Tensor& input_t, const Tensor& weight_t, const Tensor& bias_t,
    IntArrayRef padding, IntArrayRef output_padding, IntArrayRef stride, IntArrayRef dilation,
    int64_t groups, bool benchmark, bool deterministic)
{
  TensorArg input  { input_t,  "input",  1 },
            weight { weight_t, "weight", 2 },
            bias   { bias_t,   "bias",   3 };
  CheckedFrom c = "cudnn_convolution_transpose";
  auto output_t = cudnn_convolution_transpose_forward(
    c, input, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic);
  if (bias->defined()) {
    cudnn_convolution_add_bias_(c, { output_t, "result", 0 }, bias);
  }
  return output_t;
}
void raw_cudnn_convolution_backward_weight_out(
    const Tensor& grad_weight, const Tensor& grad_output, const Tensor& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic) {

  auto dataType = getCudnnDataType(input);

  ConvolutionArgs args{ input, grad_output, grad_weight };
  args.handle = getCudnnHandle();
  setConvolutionParams(&args.params, input, grad_weight, padding, stride, dilation, groups, deterministic);
  args.idesc.set(input);
  args.wdesc.set(grad_weight);
  args.odesc.set(grad_output);
  args.cdesc.set(dataType, input.dim() - 2, args.params.padding, args.params.stride, args.params.dilation, args.params.groups);

  cudnnConvolutionBwdFilterAlgoPerf_t bwdFilterAlgPerf;
  Workspace workspace = chooseAlgorithm(args, benchmark, &bwdFilterAlgPerf);

  AT_CUDNN_CHECK(cudnnSetConvolutionMathType(args.cdesc.mut_desc(), bwdFilterAlgPerf.mathType));

  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  AT_CUDNN_CHECK(cudnnConvolutionBackwardFilter(
      args.handle,
      &one, args.idesc.desc(), input.data_ptr(),
      args.odesc.desc(), grad_output.data_ptr(),
      args.cdesc.desc(), bwdFilterAlgPerf.algo, workspace.data, workspace.size,
      &zero, args.wdesc.desc(), grad_weight.data_ptr()));
}
Tensor cudnn_convolution_backward_weight(
    CheckedFrom c,
    IntArrayRef weight_size, const TensorArg& grad_output, const TensorArg& input,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  checkAllSameType(c, {grad_output, input});
  checkAllSameGPU(c, {grad_output, input});

  auto grad_weight_t = at::empty(weight_size, grad_output->options());

  TensorArg grad_weight{ grad_weight_t, "result", 0 };
  convolution_shape_check(c, input, grad_weight, grad_output, padding, stride, dilation, groups);

  raw_cudnn_convolution_backward_weight_out(
      *grad_weight, *grad_output, *input,
      padding, stride, dilation, groups, benchmark, deterministic);

  return grad_weight_t;
}
Tensor cudnn_convolution_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            input{ input_t, "input", 2 };
  setCuDNNStreamToCurrent();
  return cudnn_convolution_backward_weight(
      "cudnn_convolution_backward_weight",
      weight_size, grad_output, input,
      padding, stride, dilation, groups, benchmark, deterministic);
}
Tensor cudnn_convolution_transpose_backward_weight(
    IntArrayRef weight_size,
    const Tensor& grad_output_t,
    const Tensor& input_t,
    IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation, int64_t groups,
    bool benchmark, bool deterministic)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 },
            input{ input_t, "input", 2 };
  setCuDNNStreamToCurrent();
  // NB: for transposed convolution the roles of input and grad_output are
  // swapped relative to the ordinary backward-weight call
  return cudnn_convolution_backward_weight(
      "cudnn_convolution_backward_weight",
      weight_size, input, grad_output,
      padding, stride, dilation, groups, benchmark, deterministic);
}
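
// Bias gradient: cudnnConvolutionBackwardBias reduces grad_output over
// every dimension except the channel dimension.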
Tensor cudnn_convolution_backward_bias(
    const Tensor& grad_output_t)
{
  TensorArg grad_output{ grad_output_t, "grad_output", 1 };
  setCuDNNStreamToCurrent();

  auto grad_bias_t = at::empty(
      { grad_output->size(output_channels_dim) }, grad_output->options());

  TensorArg grad_bias{ grad_bias_t, "result", 0 };

  TensorDescriptor bdesc{grad_bias->expand({1, grad_bias->size(0)}),
                         static_cast<size_t>(grad_output->dim())};
  TensorDescriptor odesc{*grad_output};

  auto handle = getCudnnHandle();
  auto dataType = getCudnnDataType(*grad_bias);
  Constant one(dataType, 1);
  Constant zero(dataType, 0);

  AT_CUDNN_CHECK(cudnnConvolutionBackwardBias(handle, &one, odesc.desc(), grad_output->data_ptr(),
                                              &zero, bdesc.desc(), grad_bias->data_ptr()));
  return *grad_bias;
}

}}  // namespace at::native

#endif