2 #include <ATen/NativeFunctions.h> 3 #include <ATen/Config.h> 4 #include <ATen/cuda/CUDAConfig.h> 6 #if !AT_CUDNN_ENABLED() 8 namespace at {
namespace native {
12 Tensor cudnn_grid_sampler_forward(
14 AT_ERROR(
"cudnn_grid_sampler_forward: ATen not compiled with cuDNN support");
17 std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
19 const Tensor& grad_output_t) {
20 AT_ERROR(
"cudnn_grid_sampler_backward: ATen not compiled with cuDNN support");
25 #else // AT_CUDNN_ENABLED 27 #include <ATen/cudnn/Descriptors.h> 28 #include <ATen/cudnn/Types.h> 29 #include <ATen/cudnn/Utils.h> 30 #include <ATen/cuda/Exceptions.h> 32 #include <ATen/TensorUtils.h> 37 namespace at {
namespace native {
41 void setSamplerDescriptor(SpatialTransformerDescriptor& desc, cudnnDataType_t dataType,
const at::Tensor& tensor)
43 int inputSize[4] = {0};
44 for (
int i = 0; i < tensor.dim(); ++i) {
45 inputSize[i] = (int) tensor.size(i);
47 desc.set(dataType, 4, inputSize);
50 void checkGridSize(CheckedFrom c, TensorArg grid, TensorArg input)
56 checkContiguous(c, grid);
60 checkSize(c, grid, 0, input->size(0));
61 checkSize(c, grid, 3, 2);
66 Tensor cudnn_grid_sampler_forward(
69 TensorArg input{ contiguousIfZeroInStrides(input_t),
"input", 1 },
70 grid{ grid_t.contiguous(),
"grid", 2 };
71 CheckedFrom c =
"cudnn_grid_sampler_forward";
72 setCuDNNStreamToCurrent();
73 checkAllSameGPU(c, {input, grid});
74 checkAllSameType(c, {input, grid});
75 checkGridSize(c, grid, input);
76 checkDim(c, input, 4);
78 auto output_t = at::empty({0}, input->options());
79 output_t.resize_({input->size(0), input->size(1), grid->size(1), grid->size(2)});
81 TensorDescriptor idesc{ *input };
82 TensorDescriptor odesc{ output_t };
83 SpatialTransformerDescriptor desc;
85 auto handle = getCudnnHandle();
86 auto dataType = getCudnnDataType(*input);
87 setSamplerDescriptor(desc, dataType, output_t);
89 Constant one(dataType, 1);
90 Constant zero(dataType, 0);
91 AT_CUDNN_CHECK(cudnnSpatialTfSamplerForward(
93 &one, idesc.desc(), input->data_ptr(),
95 &zero, odesc.desc(), output_t.data_ptr()
103 std::tuple<Tensor, Tensor> cudnn_grid_sampler_backward(
105 const Tensor& grad_output_t)
107 TensorArg input{ contiguousIfZeroInStrides(input_t),
"input", 1 },
108 grid{ grid_t.contiguous(),
"grid", 2 },
109 grad_output{ contiguousIfZeroInStrides(grad_output_t),
"grad_output", 3 };
110 CheckedFrom c =
"cudnn_grid_sampler_backward";
111 setCuDNNStreamToCurrent();
112 checkAllSameGPU(c, {input, grad_output, grid});
113 checkGridSize(c, grid, input);
114 checkDim(c, input, 4);
115 checkDim(c, grad_output, 4);
117 auto grad_input_t = at::empty({0}, input->options());
118 grad_input_t.resize_(input->sizes());
119 auto grad_grid_t = at::empty({0}, grid->options());
120 grad_grid_t.resize_(grid->sizes());
122 TensorDescriptor idesc{ *input };
123 TensorDescriptor odesc{ *grad_output };
124 TensorDescriptor gdesc{ grad_input_t };
125 SpatialTransformerDescriptor desc;
127 auto handle = getCudnnHandle();
128 auto dataType = getCudnnDataType(*input);
129 setSamplerDescriptor(desc, dataType, *grad_output);
131 Constant one(dataType, 1);
132 Constant zero(dataType, 0);
133 AT_CUDNN_CHECK(cudnnSpatialTfSamplerBackward(
135 &one, idesc.desc(), input->data_ptr(),
136 &zero, gdesc.desc(), grad_input_t.data_ptr(),
137 &one, odesc.desc(), grad_output->data_ptr(),
140 &zero, grad_grid_t.data_ptr()
143 return std::tuple<Tensor, Tensor>{ grad_input_t, grad_grid_t };
Flush-To-Zero and Denormals-Are-Zero mode.