Caffe2 - C++ API
A deep learning, cross platform ML framework
AffineGridGenerator.cpp
1 #include <ATen/ATen.h>
2 #include <ATen/NativeFunctions.h>
3 #include <ATen/Config.h>
4 #include <ATen/cuda/CUDAConfig.h>
5 
6 #if !AT_CUDNN_ENABLED()
7 
8 namespace at { namespace native {
9 
10 // See Note [ATen preprocessor philosophy]
11 
12 Tensor cudnn_affine_grid_generator_forward(
13  const Tensor& theta,
14  int64_t N, int64_t C, int64_t H, int64_t W) {
15  AT_ERROR("cudnn_affine_grid_generator_forward: ATen not compiled with cuDNN support");
16 }
17 
18 Tensor cudnn_affine_grid_generator_backward(
19  const Tensor& grad_theta,
20  int64_t N, int64_t C, int64_t H, int64_t W) {
21  AT_ERROR("cudnn_affine_grid_generator_backward: ATen not compiled with cuDNN support");
22 }
23 
24 }}
25 
26 #else // AT_CUDNN_ENABLED()
27 
28 #include <ATen/cudnn/cudnn-wrapper.h>
29 #include <ATen/cudnn/Handle.h>
30 #include <ATen/cudnn/Descriptors.h>
31 #include <ATen/cudnn/Types.h>
32 #include <ATen/cudnn/Utils.h>
33 #include <ATen/cuda/Exceptions.h>
34 
35 #include <ATen/TensorUtils.h>
36 
37 namespace at { namespace native {
38 
39 namespace {
40 
41 void setSamplerDescriptor(SpatialTransformerDescriptor& desc,
42  cudnnDataType_t dataType,
43  int N, int C, int H, int W)
44 {
45  int inputSize[4] = {N, C, H, W};
46  desc.set(dataType, 4, inputSize);
47 }
48 
49 } // namespace
50 
51 Tensor cudnn_affine_grid_generator_forward(
52  const Tensor& theta_t,
53  int64_t N, int64_t C, int64_t H, int64_t W)
54 {
55  setCuDNNStreamToCurrent();
56 
57  TensorArg theta{ theta_t.contiguous(), "theta", 1 };
58  CheckedFrom c = "cudnn_affine_grid_generator_forward";
59  checkContiguous(c, theta);
60  checkSize(c, theta, {N, 2, 3});
61 
62  auto grid_t = at::empty({0}, theta->options());
63  grid_t.resize_({N, H, W, 2});
64 
65  auto dataType = getCudnnDataType(*theta);
66  SpatialTransformerDescriptor desc;
67  setSamplerDescriptor(desc, dataType, N, C, H, W);
68  AT_CUDNN_CHECK(cudnnSpatialTfGridGeneratorForward(getCudnnHandle(), desc.desc(),
69  theta->data_ptr(),
70  grid_t.data_ptr()));
71  return grid_t;
72 }
73 
74 Tensor cudnn_affine_grid_generator_backward(
75  const Tensor& grad_grid_t,
76  int64_t N, int64_t C, int64_t H, int64_t W)
77 {
78  setCuDNNStreamToCurrent();
79 
80  TensorArg grad_grid{ grad_grid_t.contiguous(), "grad_grid", 1 };
81  CheckedFrom c = "cudnn_affine_grid_generator_backward";
82  checkContiguous(c, grad_grid);
83  checkSize(c, grad_grid, {N, H, W, 2});
84 
85  auto grad_theta_t = at::empty({0}, grad_grid->options());
86  grad_theta_t.resize_({N, 2, 3});
87 
88  auto dataType = getCudnnDataType(grad_theta_t);
89  SpatialTransformerDescriptor desc;
90  setSamplerDescriptor(desc, dataType, N, C, H, W);
91  AT_CUDNN_CHECK(cudnnSpatialTfGridGeneratorBackward(getCudnnHandle(), desc.desc(),
92  grad_grid->data_ptr(),
93  grad_theta_t.data_ptr()));
94  return grad_theta_t;
95 }
96 
97 }} // namespace at::native
98 
99 #endif // AT_CUDNN_ENABLED()
Definition: static.cpp:64
Flush-To-Zero and Denormals-Are-Zero mode.