2 #include <ATen/NativeFunctions.h> 3 #include <ATen/Config.h> 4 #include <ATen/cuda/CUDAConfig.h> 6 #if !AT_CUDNN_ENABLED() 8 namespace at {
namespace native {
// Fallback definition used when AT_CUDNN_ENABLED() is false.
// It performs no computation: it only reports, via AT_ERROR, that ATen was
// built without cuDNN support.
12 Tensor cudnn_affine_grid_generator_forward(
14 int64_t N, int64_t
C, int64_t H, int64_t W) {
15 AT_ERROR(
"cudnn_affine_grid_generator_forward: ATen not compiled with cuDNN support");
// Fallback definition used when AT_CUDNN_ENABLED() is false.
// It performs no computation: it only reports, via AT_ERROR, that ATen was
// built without cuDNN support.
18 Tensor cudnn_affine_grid_generator_backward(
20 int64_t N, int64_t
C, int64_t H, int64_t W) {
21 AT_ERROR(
"cudnn_affine_grid_generator_backward: ATen not compiled with cuDNN support");
26 #else // AT_CUDNN_ENABLED() 28 #include <ATen/cudnn/cudnn-wrapper.h> 29 #include <ATen/cudnn/Handle.h> 30 #include <ATen/cudnn/Descriptors.h> 31 #include <ATen/cudnn/Types.h> 32 #include <ATen/cudnn/Utils.h> 33 #include <ATen/cuda/Exceptions.h> 35 #include <ATen/TensorUtils.h> 37 namespace at {
namespace native {
// Configures a spatial-transformer descriptor for input size {N, C, H, W}
// with element type `dataType`, by forwarding to desc.set().
//
// desc     - descriptor to configure (modified in place).
// dataType - cuDNN data type of the tensors the descriptor describes.
// N, C, H, W - sizes packed in that order into the array passed to desc.set.
//
// NOTE(review): the callers in this file pass int64_t sizes. They are
// narrowed to int at the call boundary without any range check, so a value
// above INT_MAX would be truncated silently. Consider validating the sizes
// before calling.
41 void setSamplerDescriptor(SpatialTransformerDescriptor& desc,
42 cudnnDataType_t dataType,
43 int N,
int C,
int H,
int W)
// The literal 4 passed to desc.set is presumably the dimension count of
// inputSize. Confirm against the descriptor's set() signature.
45 int inputSize[4] = {N, C, H, W};
46 desc.set(dataType, 4, inputSize);
// cuDNN-backed forward pass of the affine grid generator.
//
// Steps:
// 1. Checks that theta is contiguous and has size {N, 2, 3}.
// 2. Allocates an output grid of size {N, H, W, 2} with theta's tensor
//    options.
// 3. Calls cudnnSpatialTfGridGeneratorForward, using a descriptor built from
//    (N, C, H, W) and theta's cuDNN data type.
//
// NOTE(review): `theta_t` is used below, but its declaration does not appear
// next to N, C, H, W here. It is presumably the input theta tensor
// parameter; confirm against the full signature.
51 Tensor cudnn_affine_grid_generator_forward(
53 int64_t N, int64_t
C, int64_t H, int64_t W)
// Presumably binds the cuDNN handle to the current CUDA stream. Confirm.
55 setCuDNNStreamToCurrent();
// Wrap a contiguous version of theta_t. The name "theta" and position 1 are
// presumably used to label failures reported by the check* calls below.
57 TensorArg theta{ theta_t.contiguous(),
"theta", 1 };
58 CheckedFrom c =
"cudnn_affine_grid_generator_forward";
59 checkContiguous(c, theta);
60 checkSize(c, theta, {N, 2, 3});
// Create an empty tensor with theta's options, then resize it to the output
// grid shape {N, H, W, 2}.
62 auto grid_t = at::empty({0}, theta->options());
63 grid_t.resize_({N, H, W, 2});
// The descriptor's data type follows theta's dtype. The output grid shares
// theta's options, so the dtypes match.
65 auto dataType = getCudnnDataType(*theta);
66 SpatialTransformerDescriptor desc;
67 setSamplerDescriptor(desc, dataType, N, C, H, W);
68 AT_CUDNN_CHECK(cudnnSpatialTfGridGeneratorForward(getCudnnHandle(), desc.desc(),
// cuDNN-backed backward pass of the affine grid generator.
//
// Steps:
// 1. Checks that grad_grid is contiguous and has size {N, H, W, 2}.
// 2. Allocates grad_theta_t of size {N, 2, 3} with grad_grid's tensor
//    options.
// 3. Calls cudnnSpatialTfGridGeneratorBackward with grad_grid's data pointer
//    as the first data argument and grad_theta_t's data pointer as the second.
//
// NOTE(review): `grad_grid_t` is used below, but its declaration does not
// appear next to N, C, H, W here. It is presumably the incoming grid-gradient
// tensor parameter; confirm against the full signature.
74 Tensor cudnn_affine_grid_generator_backward(
76 int64_t N, int64_t C, int64_t H, int64_t W)
// Presumably binds the cuDNN handle to the current CUDA stream. Confirm.
78 setCuDNNStreamToCurrent();
// Wrap a contiguous version of grad_grid_t. The name "grad_grid" and
// position 1 are presumably used to label failures reported by the check*
// calls below.
80 TensorArg grad_grid{ grad_grid_t.contiguous(),
"grad_grid", 1 };
81 CheckedFrom c =
"cudnn_affine_grid_generator_backward";
82 checkContiguous(c, grad_grid);
83 checkSize(c, grad_grid, {N, H, W, 2});
// Create an empty tensor with grad_grid's options, then resize it to theta's
// shape {N, 2, 3}.
85 auto grad_theta_t = at::empty({0}, grad_grid->options());
86 grad_theta_t.resize_({N, 2, 3});
// grad_theta_t was created with grad_grid's options, so this data type also
// matches grad_grid.
88 auto dataType = getCudnnDataType(grad_theta_t);
89 SpatialTransformerDescriptor desc;
90 setSamplerDescriptor(desc, dataType, N, C, H, W);
91 AT_CUDNN_CHECK(cudnnSpatialTfGridGeneratorBackward(getCudnnHandle(), desc.desc(),
92 grad_grid->data_ptr(),
93 grad_theta_t.data_ptr()));
99 #endif // AT_CUDNN_ENABLED()
Flush-To-Zero and Denormals-Are-Zero mode.