Caffe2 - C++ API
A deep learning, cross platform ML framework
pool_gradient_op.cc
#include "caffe2/operators/pool_op.h"

namespace caffe2 {

using std::max;
using std::min;

namespace {
// These two classes are just used as template arguments passed to the
// PoolGradientOp template to instantiate the different algorithms.
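// Each class provides two process_grad overloads: a scalar form used by the
// NCHW code path (one input/output element at a time) and an Eigen column
// form used by the NHWC code path (one channel vector per spatial location).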
template <typename T>
class AveragePool {
 public:
  static void process_grad(
      const T& /*x_data*/,
      const T& /*y_data*/,
      const T& dy_data,
      const T& scale,
      T& dx_data) {
    dx_data += (scale * dy_data);
  }

  static void process_grad(
      const int y_col,
      const int x_col,
      const float scale,
      ConstEigenArrayMap<float>& /*x_data*/,
      ConstEigenArrayMap<float>& /*y_data*/,
      ConstEigenArrayMap<float>& dy_data,
      EigenArrayMap<float>& dx_data) {
    dx_data.col(x_col) += scale * dy_data.col(y_col);
  }
};

template <typename T>
class MaxPool {
 public:
  static void process_grad(
      const T& x_data,
      const T& y_data,
      const T& dy_data,
      const T& /*scale*/,
      T& dx_data) {
    if (x_data == y_data) {
      dx_data += dy_data;
    }
  }

  static void process_grad(
      const int y_col,
      const int x_col,
      const float /*scale*/,
      ConstEigenArrayMap<float>& x_data,
      ConstEigenArrayMap<float>& y_data,
      ConstEigenArrayMap<float>& dy_data,
      EigenArrayMap<float>& dx_data) {
    dx_data.col(x_col) += dy_data.col(y_col) *
        (x_data.col(x_col)
             .cwiseEqual(y_data.col(y_col))
             .template cast<float>());
  }
};
} // namespace

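// Pooling gradients computed below:
//   AveragePool: each dY value is spread uniformly over its pooling window,
//   i.e. dX[i] += dY[p] / |window(p)| for every window p that covers i.
//   MaxPool: each dY value is routed only to input positions whose value
//   equals the pooled maximum, i.e. dX[i] += dY[p] * (X[i] == Y[p]);
//   ties receive the gradient at every matching position.
// Example (1-D AveragePool, kernel 2, stride 2): if Y[0] pooled X[0..1],
// then dX[0] += 0.5 * dY[0] and dX[1] += 0.5 * dY[0].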
template <typename T, class Context, typename PoolType>
bool PoolGradientOp<T, Context, PoolType>::RunOnDeviceWithOrderNCHW() {
  auto& X = Input(0);
  auto& Y = Input(1);
  auto& dY = Input(2);
  auto* dX = Output(0);
  // TODO(Yangqing): Add shape checks.
  dX->ResizeLike(X);
  math::Set<float, CPUContext>(
      X.size(), 0, dX->template mutable_data<float>(), &context_);
  const float* Xdata = X.template data<float>();
  const float* Ydata = Y.template data<float>();
  const float* dYdata = dY.template data<float>();
  float* dXdata = dX->template mutable_data<float>();
  int channels = X.dim32(1);
  CAFFE_ENFORCE_EQ(channels, dY.dim32(1));
  int height = X.dim32(2);
  int width = kernel_.size() > 1 ? X.dim32(3) : 1;
  int depth = kernel_.size() > 2 ? X.dim32(4) : 1;
  vector<int> dims(X.dims().begin() + 2, X.dims().end());
  ConvPoolOpBase<CPUContext>::ComputePads(dims);
  int pooled_height = dY.dim32(2);
  int pooled_width = kernel_.size() > 1 ? dY.dim32(3) : 1;
  int pooled_depth = kernel_.size() > 2 ? dY.dim32(4) : 1;
  // The main loop
  switch (kernel_.size()) {
    case 1:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int c = 0; c < channels; ++c) {
          for (int ph = 0; ph < pooled_height; ++ph) {
            int hstart = ph * stride_h() - pad_t();
            int hend = min(hstart + kernel_h(), height);
            hstart = max(hstart, 0);
            float scale = 1. / (hend - hstart);
            for (int h = hstart; h < hend; ++h) {
              PoolType::process_grad(
                  Xdata[h], Ydata[ph], dYdata[ph], scale, dXdata[h]);
            }
          }
          // offset
          Xdata += height;
          dXdata += height;
          Ydata += pooled_height;
          dYdata += pooled_height;
        }
      }
      break;
    case 2:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int c = 0; c < channels; ++c) {
          for (int ph = 0; ph < pooled_height; ++ph) {
            int hstart = ph * stride_h() - pad_t();
            int hend = min(hstart + kernel_h(), height);
            hstart = max(hstart, 0);
            for (int pw = 0; pw < pooled_width; ++pw) {
              int wstart = pw * stride_w() - pad_l();
              int wend = min(wstart + kernel_w(), width);
              wstart = max(wstart, 0);
              float scale = 1. / (hend - hstart) / (wend - wstart);
              const int pooled_index = ph * pooled_width + pw;
              for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                  const int index = h * width + w;
                  PoolType::process_grad(
                      Xdata[index],
                      Ydata[pooled_index],
                      dYdata[pooled_index],
                      scale,
                      dXdata[index]);
                }
              }
            }
          }
          // offset
          Xdata += height * width;
          dXdata += height * width;
          Ydata += pooled_height * pooled_width;
          dYdata += pooled_height * pooled_width;
        }
      }
      break;
    case 3:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int c = 0; c < channels; ++c) {
          for (int ph = 0; ph < pooled_height; ++ph) {
            int hstart = ph * stride_h() - pad_t();
            int hend = min(hstart + kernel_h(), height);
            hstart = max(hstart, 0);
            for (int pw = 0; pw < pooled_width; ++pw) {
              int wstart = pw * stride_w() - pad_l();
              int wend = min(wstart + kernel_w(), width);
              wstart = max(wstart, 0);
              for (int pd = 0; pd < pooled_depth; ++pd) {
                int dstart = pd * stride_[2] - pads_[2];
                int dend = min(dstart + kernel_[2], depth);
                dstart = max(dstart, 0);
                float scale =
                    1. / (hend - hstart) / (wend - wstart) / (dend - dstart);
                const int pooled_index =
                    ph * pooled_width * pooled_depth + pw * pooled_depth + pd;
                for (int h = hstart; h < hend; ++h) {
                  for (int w = wstart; w < wend; ++w) {
                    for (int d = dstart; d < dend; ++d) {
                      const int index = h * width * depth + w * depth + d;
                      PoolType::process_grad(
                          Xdata[index],
                          Ydata[pooled_index],
                          dYdata[pooled_index],
                          scale,
                          dXdata[index]);
                    }
                  }
                }
              }
            }
          }
          // offset
          Xdata += height * width * depth;
          dXdata += height * width * depth;
          Ydata += pooled_height * pooled_width * pooled_depth;
          dYdata += pooled_height * pooled_width * pooled_depth;
        }
      }
      break;
    default:
      CAFFE_THROW("Unsupported pooling size");
      return false;
  }
  return true;
}

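// NHWC version: channels are the innermost dimension, so each tensor is
// viewed as an Eigen array with `channels` rows and one column per
// (batch, spatial) location, and whole columns are updated at once via the
// Eigen process_grad overloads above.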
template <typename T, class Context, typename PoolType>
bool PoolGradientOp<T, Context, PoolType>::RunOnDeviceWithOrderNHWC() {
  auto& X = Input(0);
  auto& Y = Input(1);
  auto& dY = Input(2);
  DCHECK_EQ(dY.ndim(), kernel_.size() + 2);
  auto* dX = Output(0);
  dX->ResizeLike(X);

  int channels = X.dim32(X.ndim() - 1);
  CAFFE_ENFORCE_EQ(channels, dY.dim32(dY.ndim() - 1));
  ConstEigenArrayMap<T> Ymat(
      Y.template data<float>(), channels, Y.size() / channels);
  ConstEigenArrayMap<float> dYmat(
      dY.template data<float>(), channels, Y.size() / channels);
  ConstEigenArrayMap<float> Xmat(
      X.template data<float>(), channels, X.size() / channels);
  EigenArrayMap<float> dXmat(
      dX->template mutable_data<float>(), channels, X.size() / channels);
  dXmat.setZero();
  int height = X.dim32(1);
  int width = kernel_.size() > 1 ? X.dim32(2) : 1;
  int depth = kernel_.size() > 2 ? X.dim32(3) : 1;
  vector<int> dims(X.dims().begin() + 1, X.dims().end() - 1);
  ConvPoolOpBase<CPUContext>::ComputePads(dims);
  int pooled_height = dY.dim32(1);
  int pooled_width = kernel_.size() > 1 ? dY.dim32(2) : 1;
  int pooled_depth = kernel_.size() > 2 ? dY.dim32(3) : 1;

  // The main loop
  // Do not do openmp here: the following for loops are looping over the pooled
  // output, so if one parallelizes the outer loops, race conditions could
  // happen in the inner loops.
  switch (kernel_.size()) {
    case 1:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int ph = 0; ph < pooled_height; ++ph) {
          int hstart = ph * stride_h() - pad_t();
          int hend = min(hstart + kernel_h(), height);
          hstart = max(hstart, 0);
          const int pool_index = n * pooled_height + ph;
          const float scale = 1. / (hend - hstart);
          for (int h = hstart; h < hend; ++h) {
            const int input_index = n * height + h;
            PoolType::process_grad(
                pool_index, input_index, scale, Xmat, Ymat, dYmat, dXmat);
          }
        }
      }
      break;
    case 2:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int ph = 0; ph < pooled_height; ++ph) {
          int hstart = ph * stride_h() - pad_t();
          int hend = min(hstart + kernel_h(), height);
          hstart = max(hstart, 0);
          for (int pw = 0; pw < pooled_width; ++pw) {
            int wstart = pw * stride_w() - pad_l();
            int wend = min(wstart + kernel_w(), width);
            wstart = max(wstart, 0);
            const int pool_index = (n * pooled_height + ph) * pooled_width + pw;
            const float scale = 1. / (hend - hstart) / (wend - wstart);
            for (int h = hstart; h < hend; ++h) {
              for (int w = wstart; w < wend; ++w) {
                const int input_index = (n * height + h) * width + w;
                PoolType::process_grad(
                    pool_index, input_index, scale, Xmat, Ymat, dYmat, dXmat);
              }
            }
          }
        }
      }
      break;
    case 3:
      for (int n = 0; n < X.dim32(0); ++n) {
        for (int ph = 0; ph < pooled_height; ++ph) {
          int hstart = ph * stride_h() - pad_t();
          int hend = min(hstart + kernel_h(), height);
          hstart = max(hstart, 0);
          for (int pw = 0; pw < pooled_width; ++pw) {
            int wstart = pw * stride_w() - pad_l();
            int wend = min(wstart + kernel_w(), width);
            wstart = max(wstart, 0);
            for (int pd = 0; pd < pooled_depth; ++pd) {
              int dstart = pd * stride_[2] - pads_[2];
              int dend = min(dstart + kernel_[2], depth);
              dstart = max(dstart, 0);
              const int pool_index =
                  ((n * pooled_height + ph) * pooled_width + pw) *
                      pooled_depth +
                  pd;
              const float scale =
                  1. / (hend - hstart) / (wend - wstart) / (dend - dstart);
              for (int h = hstart; h < hend; ++h) {
                for (int w = wstart; w < wend; ++w) {
                  for (int d = dstart; d < dend; ++d) {
                    const int input_index =
                        ((n * height + h) * width + w) * depth + d;
                    PoolType::process_grad(
                        pool_index,
                        input_index,
                        scale,
                        Xmat,
                        Ymat,
                        dYmat,
                        dXmat);
                  }
                }
              }
            }
          }
        }
      }
      break;
    default:
      CAFFE_THROW("Unsupported pooling size");
      return false;
  }
  return true;
}

REGISTER_CPU_OPERATOR(
    AveragePoolGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePoolGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    AveragePool1DGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePool1DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    AveragePool2DGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePool2DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    AveragePool3DGradient,
    PoolGradientOp<float, CPUContext, AveragePool<float>>);
OPERATOR_SCHEMA(AveragePool3DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPoolGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPoolGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPool1DGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPool1DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPool2DGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPool2DGradient).NumInputs(3).NumOutputs(1);

REGISTER_CPU_OPERATOR(
    MaxPool3DGradient,
    PoolGradientOp<float, CPUContext, MaxPool<float>>);
OPERATOR_SCHEMA(MaxPool3DGradient).NumInputs(3).NumOutputs(1);

class GetPoolGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        def_.type() + "Gradient",
        "",
        vector<string>{I(0), O(0), GO(0)},
        vector<string>{GI(0)});
  }
};
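// For a forward op of type "MaxPool" with input X and output Y, the maker
// above emits an op of type "MaxPoolGradient" whose inputs are {X, Y, dY}
// and whose single output is dX, matching the NumInputs(3)/NumOutputs(1)
// schemas registered above.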
REGISTER_GRADIENT(AveragePool, GetPoolGradient);
REGISTER_GRADIENT(AveragePool1D, GetPoolGradient);
REGISTER_GRADIENT(AveragePool2D, GetPoolGradient);
REGISTER_GRADIENT(AveragePool3D, GetPoolGradient);
REGISTER_GRADIENT(MaxPool, GetPoolGradient);
REGISTER_GRADIENT(MaxPool1D, GetPoolGradient);
REGISTER_GRADIENT(MaxPool2D, GetPoolGradient);
REGISTER_GRADIENT(MaxPool3D, GetPoolGradient);
} // namespace caffe2