Caffe2 - C++ API
A deep learning, cross platform ML framework
pool_op_rtc_gpu.cc
1 #include <cstdio>
2 
3 #include "caffe2/core/common_gpu.h"
4 #include "caffe2/core/context_gpu.h"
5 #include "caffe2/operators/pool_op.h"
6 #include "caffe2/cuda_rtc/common_rtc.h"
7 
8 namespace caffe2 {
namespace {
// Empty tag types naming the two pooling flavors. They carry no data or
// behavior; nothing in the visible part of this file references them.
struct AveragePool {};
struct MaxPool {};
} // namespace
13 
14 namespace {
15 
// NVRTC source template for the forward max-pool kernel (NCHW layout).
// The string is used as an snprintf format: the kernel name replaces %s and
// the following twelve %d placeholders receive, in order: nthreads (output
// element count), channels, height, width, pooled_height, pooled_width,
// kernel_h, kernel_w, stride_h, stride_w, pad_t, pad_l. Because of that, any
// literal '%' in the kernel text must be written as '%%'.
//
// Each loop iteration computes one output element; the loop advances by
// blockDim.x * gridDim.x, so any 1-D launch size covers all nthreads outputs.
// The running maximum is seeded with -FLT_MAX (written as a literal because
// the generated source includes no headers), so inputs below the previous
// -1.0e37f sentinel still yield their true maximum.
const char kMaxPoolForwardNCHWSource[] = R"(
extern "C"
__global__ void %s(const float* bottom_data, float* top_data) {
  const int nthreads = %d;
  const int channels = %d;
  const int height = %d;
  const int width = %d;
  const int pooled_height = %d;
  const int pooled_width = %d;
  const int kernel_h = %d;
  const int kernel_w = %d;
  const int stride_h = %d;
  const int stride_w = %d;
  const int pad_t = %d;
  const int pad_l = %d;
  for (int index = blockIdx.x * blockDim.x + threadIdx.x;
       index < nthreads; index += blockDim.x * gridDim.x) {
    int pw = index %% pooled_width;
    int ph = (index / pooled_width) %% pooled_height;
    int c = (index / (pooled_width * pooled_height)) %% channels;
    int n = index / (pooled_width * pooled_height * channels);
    int hstart = ph * stride_h - pad_t;
    int wstart = pw * stride_w - pad_l;
    int hend = min(hstart + kernel_h, height);
    int wend = min(wstart + kernel_w, width);
    hstart = max(hstart, 0);
    wstart = max(wstart, 0);
    // -FLT_MAX spelled out: no headers are available in this source.
    float maxval = -3.402823466e+38f;
    const float* bdata_offset = bottom_data + n * channels * height * width;
    for (int h = hstart; h < hend; ++h) {
      for (int w = wstart; w < wend; ++w) {
        maxval = fmaxf(
            bdata_offset[c * height * width + h * width + w], maxval);
      }
    }
    top_data[index] = maxval;
  }
}
)";
56 
// NVRTC source template for the max pool BACKWARD kernel (NCHW layout).
// The string is used as an snprintf format: the kernel name replaces %s and
// the following thirteen %d placeholders receive, in order: nthreads, num,
// channels, height, width, pooled_height, pooled_width, kernel_h, kernel_w,
// stride_h, stride_w, pad_t, pad_l. Literal '%' in the kernel must be '%%'.
//
// Each loop iteration handles one input (bottom) element; the gradient op
// passes the input element count as nthreads. For every pooling window that
// covers the element, the window's top_diff is accumulated into bottom_diff
// whenever the element equals that window's forward output, so tied maxima
// all receive the gradient. `num` is substituted but unused in the body.
const char kMaxPoolBackwardNCHWSource[] = R"(
extern "C"
__global__ void %s(
 const float* const bottom_data, const float* const top_data,
 const float* const top_diff, float* const bottom_diff) {
 const int nthreads = %d;
 const int num = %d;
 const int channels = %d;
 const int height = %d;
 const int width = %d;
 const int pooled_height = %d;
 const int pooled_width = %d;
 const int kernel_h = %d;
 const int kernel_w = %d;
 const int stride_h = %d;
 const int stride_w = %d;
 const int pad_t = %d;
 const int pad_l = %d;
 for (int index = blockIdx.x * blockDim.x + threadIdx.x;
 index < nthreads; index += blockDim.x * gridDim.x) {
 const int w = index %% width + pad_l;
 const int h = (index / width) %% height + pad_t;
 const int c = (index / width / height) %% channels;
 const int n = index / width / height / channels;
 const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1;
 const int phend = min(h / stride_h + 1, pooled_height);
 const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1;
 const int pwend = min(w / stride_w + 1, pooled_width);
 const int top_offset =
 (n * channels + c) * pooled_height * pooled_width;
 bottom_diff[index] = 0;
 for (int ph = phstart; ph < phend; ++ph) {
 for (int pw = pwstart; pw < pwend; ++pw) {
 int top_local_offset = top_offset + ph * pooled_width + pw;
 if (bottom_data[index] == top_data[top_local_offset]) {
 bottom_diff[index] += top_diff[top_local_offset];
 }
 }
 }
 }
}
)";
100 
101 
102 class MaxPoolRTCFunction : public CudaRTCFunction<MaxPoolRTCFunction> {
103  public:
104  MaxPoolRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}
105 
106  template <typename... Args>
107  string KernelName(Args... /*args*/) {
108  return name_;
109  }
110 
111  template <typename... Args>
112  string GetSource(Args... args);
113 
114  private:
115  string name_;
116 };
117 
118 class MaxPoolGradientRTCFunction
119  : public CudaRTCFunction<MaxPoolGradientRTCFunction> {
120  public:
121  MaxPoolGradientRTCFunction() : CudaRTCFunction(), name_(GetUniqueName()) {}
122 
123  template <typename... Args>
124  string KernelName(Args... /*args*/) {
125  return name_;
126  }
127 
128  template <typename... Args>
129  string GetSource(Args... args);
130 
131  private:
132  string name_;
133 };
134 
135 
136 template <>
137 string MaxPoolRTCFunction::GetSource(
138  const int output_size,
139  const int channels,
140  const int height,
141  const int width,
142  const int pooled_height,
143  const int pooled_width,
144  const int kernel_h,
145  const int kernel_w,
146  const int stride_h,
147  const int stride_w,
148  const int pad_t,
149  const int pad_l) {
150  char buffer[65536];
151  int nbytes = snprintf(
152  buffer, 65536, kMaxPoolForwardNCHWSource, name_.c_str(), output_size,
153  channels, height, width, pooled_height, pooled_width, kernel_h, kernel_w,
154  stride_h, stride_w, pad_t, pad_l);
155  DCHECK_GE(nbytes, 0);
156  DCHECK_LT(nbytes, 65536);
157  return string(buffer);
158 }
159 
160 template <>
161 string MaxPoolGradientRTCFunction::GetSource(
162  const int output_size,
163  const int num,
164  const int channels,
165  const int height,
166  const int width,
167  const int pooled_height,
168  const int pooled_width,
169  const int kernel_h,
170  const int kernel_w,
171  const int stride_h,
172  const int stride_w,
173  const int pad_t,
174  const int pad_l) {
175  char buffer[65536];
176  int nbytes = snprintf(
177  buffer, 65536, kMaxPoolBackwardNCHWSource, name_.c_str(), output_size,
178  num, channels, height, width, pooled_height, pooled_width, kernel_h,
179  kernel_w, stride_h, stride_w, pad_t, pad_l);
180  DCHECK_GE(nbytes, 0);
181  DCHECK_LT(nbytes, 65536);
182  return string(buffer);
183 }
184 
185 } // namespace
186 
187 
188 class MaxPoolRTCOp final : public ConvPoolOpBase<CUDAContext> {
189  public:
190  MaxPoolRTCOp(const OperatorDef& operator_def, Workspace* ws)
191  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {
192  CAFFE_ENFORCE_EQ(
193  order_, StorageOrder::NCHW, "Currently only NCHW is supported.");
194  }
195  ~MaxPoolRTCOp() override {}
196 
197  bool RunOnDeviceWithOrderNCHW() override {
198  auto& X = Input(0);
199  auto output_sizes = ConvPoolOpBase<CUDAContext>::GetOutputSize(X, X.dim32(1));
200  auto* Y = Output(0, output_sizes, at::dtype<float>());
201 
202  if (input_dims_ != X.sizes()) {
203  // recompile
204  VLOG(1) << "MaxPool RTC recompiling";
205  CAFFE_ENFORCE_LT(Y->numel(), std::numeric_limits<int>::max());
206  func_.Compile(
207  static_cast<int>(Y->numel()),
208  X.dim32(1),
209  X.dim32(2),
210  X.dim32(3),
211  Y->dim32(2),
212  Y->dim32(3),
213  kernel_h(),
214  kernel_w(),
215  stride_h(),
216  stride_w(),
217  pad_t(),
218  pad_l());
219  input_dims_ = X.sizes().vec();
220  }
221  // Carry out the pooling computation.
222  func_.Launch(
223  CAFFE_GET_BLOCKS(Y->numel()),
224  1,
225  1,
226  CAFFE_CUDA_NUM_THREADS,
227  1,
228  1,
229  0,
230  context_.cuda_stream(),
231  X.data<float>(),
232  Y->mutable_data<float>());
233  return true;
234  }
235 
236  bool RunOnDeviceWithOrderNHWC() override {
237  LOG(FATAL) << "Not implemented.";
238  return false;
239  }
240 
241  private:
242  MaxPoolRTCFunction func_;
243  vector<int64_t> input_dims_;
244 };
245 
246 class MaxPoolGradientRTCOp final : public ConvPoolOpBase<CUDAContext> {
247  public:
248  MaxPoolGradientRTCOp(const OperatorDef& operator_def, Workspace* ws)
249  : ConvPoolOpBase<CUDAContext>(operator_def, ws) {
250  CAFFE_ENFORCE_EQ(
251  order_, StorageOrder::NCHW, "Currently only NCHW is supported.");
252  }
253  ~MaxPoolGradientRTCOp() override {}
254 
255  bool RunOnDeviceWithOrderNCHW() override {
256  auto& X = Input(0);
257  auto& Y = Input(1);
258  auto& dY = Input(2);
259  CAFFE_ENFORCE_EQ(dY.dim(), 4);
260 
261  auto* dX = Output(0, X.sizes(), at::dtype<float>());
262  ConvPoolOpBase<CUDAContext>::ComputePads({X.dim32(2), X.dim32(3)});
263  if (input_dims_ != X.sizes()) {
264  VLOG(1) << "MaxPoolGradient RTC recompiling";
265  CAFFE_ENFORCE_LT(X.numel(), std::numeric_limits<int>::max());
266  func_.Compile(
267  static_cast<int>(X.numel()),
268  X.dim32(0),
269  X.dim32(1),
270  X.dim32(2),
271  X.dim32(3),
272  dY.dim32(2),
273  dY.dim32(3),
274  kernel_h(),
275  kernel_w(),
276  stride_h(),
277  stride_w(),
278  pad_t(),
279  pad_l());
280  input_dims_ = X.sizes().vec();
281  }
282  func_.Launch(
283  CAFFE_GET_BLOCKS(X.numel()),
284  1,
285  1,
286  CAFFE_CUDA_NUM_THREADS,
287  1,
288  1,
289  0,
290  context_.cuda_stream(),
291  X.data<float>(),
292  Y.data<float>(),
293  dY.data<float>(),
294  dX->mutable_data<float>());
295  return true;
296  }
297 
298  bool RunOnDeviceWithOrderNHWC() override {
299  LOG(FATAL) << "Not implemented.";
300  return false;
301  }
302 
303  private:
304  MaxPoolGradientRTCFunction func_;
305  vector<int64_t> input_dims_;
306 };
307 
308 namespace {
309 REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPool, NVRTC, MaxPoolRTCOp);
310 REGISTER_CUDA_OPERATOR_WITH_ENGINE(MaxPoolGradient, NVRTC,
312 } // namespace
313 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13
int CAFFE_GET_BLOCKS(const int N)
Compute the number of blocks needed to run N threads.
Definition: common_gpu.h:340