Caffe2 - C++ API
A deep learning, cross-platform ML framework
pool_op.h
1 #ifndef CAFFE2_OPERATORS_POOL_OP_H_
2 #define CAFFE2_OPERATORS_POOL_OP_H_
3 
4 #include <vector>
5 
6 #include "caffe2/core/common_omp.h"
7 #include "caffe2/core/context.h"
8 #include "caffe2/core/logging.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/operators/conv_pool_op_base.h"
11 
12 namespace caffe2 {
13 
// PoolOp: forward pooling operator. The concrete pooling computation
// (average, max, ...) is supplied by the Functor template parameter;
// this class only validates the pooling arguments at construction time
// and dispatches to the functor per storage order (NCHW / NHWC).
14 template <typename T, class Context, class Functor>
15 class PoolOp final : public ConvPoolOpBase<Context> {
16  public:
17  USE_CONV_POOL_BASE_FUNCTIONS(Context);
18 
19  template <class... Args>
// Forwards all operator arguments to ConvPoolOpBase, then enforces:
//  - dilation_[i] == 1 for every spatial dim (dilation unsupported), and
//  - unless global pooling, every pad is strictly smaller than the
//    kernel extent in that dimension.
20  explicit PoolOp(Args&&... args)
21  : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {
22  const int kernel_size = kernel_.size();
23  for (int i = 0; i < kernel_size; ++i) {
24  CAFFE_ENFORCE_EQ(
25  dilation_[i], 1, "Pooling op does not support dilation right now.");
26  }
27  if (!global_pooling_) {
28  for (int i = 0; i < kernel_size; ++i) {
// pads_ stores leading pads in [0, kernel_size) and trailing pads in
// [kernel_size, 2*kernel_size); both sides must be < kernel_[i].
29  CAFFE_ENFORCE(
30  pads_[i] < kernel_[i] && pads_[i + kernel_size] < kernel_[i],
31  "Pad should be smaller than kernel.");
32  }
33  }
34  }
35 
36  ~PoolOp() = default;
37 
// Forward pass, NCHW layout: Input(0) = X, Output(0) = Y.
// NOTE(review): original line 43 is missing from this extraction —
// presumably ConvPoolOpBase<Context>::SetOutputSize(X, Y, C); which
// shapes Y before data<T>() is taken. Confirm against the upstream file.
38  bool RunOnDeviceWithOrderNCHW() override {
39  const auto& X = Input(0);
40  auto* Y = Output(0);
41  const int N = X.dim32(0);
42  const int C = X.dim32(1);
44  const T* X_data = X.template data<T>();
45  T* Y_data = Y->template mutable_data<T>();
// Empty batch: nothing to compute, succeed immediately.
46  if (N == 0) {
47  return true;
48  }
// Global pooling collapses all spatial dims; HxW is their product.
49  if (global_pooling_) {
50  const int HxW = X.numel() / (N * C);
51  return functor_.template GlobalPoolingForward<T, StorageOrder::NCHW>(
52  N, C, HxW, X_data, Y_data, &context_);
53  }
54  const std::vector<int> X_HW_dims = GetDims(X);
55  const std::vector<int> Y_HW_dims = GetDims(*Y);
56  return functor_.template Forward<T, StorageOrder::NCHW>(
57  N,
58  C,
59  X_HW_dims,
60  Y_HW_dims,
61  kernel_,
62  dilation_,
63  stride_,
64  pads_,
65  X.template data<T>(),
66  Y->template mutable_data<T>(),
67  &context_);
68  }
69 
// Forward pass, NHWC layout: channels are the last dimension.
// NOTE(review): original line 76 is missing from this extraction —
// presumably the matching SetOutputSize(X, Y, C) call. Confirm upstream.
70  bool RunOnDeviceWithOrderNHWC() override {
71  const auto& X = Input(0);
72  auto* Y = Output(0);
73  const int ndim = X.dim();
74  const int N = X.dim32(0);
75  const int C = X.dim32(ndim - 1);
77  const T* X_data = X.template data<T>();
78  T* Y_data = Y->template mutable_data<T>();
79  if (N == 0) {
80  return true;
81  }
82  if (global_pooling_) {
83  const int HxW = X.numel() / (N * C);
84  return functor_.template GlobalPoolingForward<T, StorageOrder::NHWC>(
85  N, C, HxW, X_data, Y_data, &context_);
86  }
87  const std::vector<int> X_HW_dims = GetDims(X);
88  const std::vector<int> Y_HW_dims = GetDims(*Y);
89  return functor_.template Forward<T, StorageOrder::NHWC>(
90  N,
91  C,
92  X_HW_dims,
93  Y_HW_dims,
94  kernel_,
95  dilation_,
96  stride_,
97  pads_,
98  X.template data<T>(),
99  Y->template mutable_data<T>(),
100  &context_);
101  }
102 
103  private:
// The functor is constructed from *this so it can read operator arguments
// (see e.g. AveragePoolFunctor reading "count_include_pad" below).
104  const Functor functor_;
105 };
106 
// PoolGradientOp: backward pass for PoolOp. Inputs are X (forward input),
// Y (forward output) and dY (gradient w.r.t. Y); the single output dX has
// the same shape as X. The actual gradient math lives in the Functor.
107 template <typename T, class Context, class Functor>
108 class PoolGradientOp final : public ConvPoolOpBase<Context> {
109  public:
110  USE_CONV_POOL_BASE_FUNCTIONS(Context);
111  template <class... Args>
112  explicit PoolGradientOp(Args&&... args)
113  : ConvPoolOpBase<Context>(std::forward<Args>(args)...), functor_(*this) {}
114 
115  ~PoolGradientOp() = default;
116 
// Backward pass, NCHW layout.
// NOTE(review): original line 126 is missing from this extraction —
// upstream has ConvPoolOpBase<Context>::ComputePads(X_HW_dims); there.
// Confirm against the upstream file before relying on this view.
117  bool RunOnDeviceWithOrderNCHW() override {
118  const auto& X = Input(0);
119  const auto& Y = Input(1);
120  const auto& dY = Input(2);
// dX is allocated with X's shape and dtype T.
121  auto* dX = Output(0, X.sizes(), at::dtype<T>());
122  const int N = X.dim32(0);
123  const int C = X.dim32(1);
124  const std::vector<int> X_HW_dims = GetDims(X);
125  const std::vector<int> Y_HW_dims = GetDims(Y);
127  const T* dY_data = dY.template data<T>();
128  const T* X_data = X.template data<T>();
129  const T* Y_data = Y.template data<T>();
130  T* dX_data = dX->template mutable_data<T>();
// Empty batch: dX is empty too, nothing to fill.
131  if (N == 0) {
132  return true;
133  }
134  if (global_pooling_) {
135  const int HxW = X.numel() / (N * C);
136  return functor_.template GlobalPoolingBackward<T, StorageOrder::NCHW>(
137  N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
138  }
139  return functor_.template Backward<T, StorageOrder::NCHW>(
140  N,
141  C,
142  X_HW_dims,
143  Y_HW_dims,
144  kernel_,
145  dilation_,
146  stride_,
147  pads_,
148  dY_data,
149  X_data,
150  Y_data,
151  dX_data,
152  &context_);
153  }
154 
// Backward pass, NHWC layout (channels last).
// NOTE(review): original line 165 is missing from this extraction —
// presumably the matching ComputePads(X_HW_dims) call. Confirm upstream.
155  bool RunOnDeviceWithOrderNHWC() override {
156  const auto& X = Input(0);
157  const auto& Y = Input(1);
158  const auto& dY = Input(2);
159  auto* dX = Output(0, X.sizes(), at::dtype<T>());
160  const int ndim = X.dim();
161  const int N = X.dim32(0);
162  const int C = X.dim32(ndim - 1);
163  const std::vector<int> X_HW_dims = GetDims(X);
164  const std::vector<int> Y_HW_dims = GetDims(Y);
166  const T* dY_data = dY.template data<T>();
167  const T* X_data = X.template data<T>();
168  const T* Y_data = Y.template data<T>();
169  T* dX_data = dX->template mutable_data<T>();
170  if (N == 0) {
171  return true;
172  }
173  if (global_pooling_) {
174  const int HxW = X.numel() / (N * C);
175  return functor_.template GlobalPoolingBackward<T, StorageOrder::NHWC>(
176  N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
177  }
178  return functor_.template Backward<T, StorageOrder::NHWC>(
179  N,
180  C,
181  X_HW_dims,
182  Y_HW_dims,
183  kernel_,
184  dilation_,
185  stride_,
186  pads_,
187  dY_data,
188  X_data,
189  Y_data,
190  dX_data,
191  &context_);
192  }
193 
194  private:
195  const Functor functor_;
196 };
197 
// AveragePoolFunctor: average-pooling kernels used by PoolOp/PoolGradientOp.
// Only declarations are visible here; definitions live in the per-device
// implementation files.
// NOTE(review): original line 199 (presumably `struct AveragePoolFunctor {`)
// is missing from this extraction — the struct header must be restored from
// the upstream file before this compiles.
198 template <class Context>
// Reads the "count_include_pad" operator argument (default false): whether
// padded cells count toward the divisor when averaging.
200  explicit AveragePoolFunctor(const OperatorBase& op)
201  : count_include_pad(
202  op.template GetSingleArgument<bool>("count_include_pad", false)) {}
203 
// Average over all HxW spatial positions per (n, c).
204  template <typename T, StorageOrder kOrder>
205  bool GlobalPoolingForward(
206  int N,
207  int C,
208  int HxW,
209  const T* X,
210  T* Y,
211  Context* context) const;
212 
// Windowed average pooling with the given kernel/dilation/stride/pads.
213  template <typename T, StorageOrder kOrder>
214  bool Forward(
215  int N,
216  int C,
217  const std::vector<int>& X_dims,
218  const std::vector<int>& Y_dims,
219  const std::vector<int>& kernel,
220  const std::vector<int>& dilation,
221  const std::vector<int>& stride,
222  const std::vector<int>& pads,
223  const T* X,
224  T* Y,
225  Context* context) const;
226 
// Gradient of global average pooling.
227  template <typename T, StorageOrder kOrder>
228  bool GlobalPoolingBackward(
229  int N,
230  int C,
231  int HxW,
232  const T* dY,
233  const T* X,
234  const T* Y,
235  T* dX,
236  Context* context) const;
237 
// Gradient of windowed average pooling.
238  template <typename T, StorageOrder kOrder>
239  bool Backward(
240  int N,
241  int C,
242  const std::vector<int>& X_dims,
243  const std::vector<int>& Y_dims,
244  const std::vector<int>& kernel,
245  const std::vector<int>& dilation,
246  const std::vector<int>& stride,
247  const std::vector<int>& pads,
248  const T* dY,
249  const T* X,
250  const T* Y,
251  T* dX,
252  Context* context) const;
253 
254  const bool count_include_pad;
// Scratch tensor of ones on the functor's device; presumably used by a
// GEMM/GEMV-based implementation — confirm in the .cc/.cu definitions.
255  Tensor ones{Context::GetDeviceType()};
256 };
257 
// MaxPoolFunctor: max-pooling kernels used by PoolOp/PoolGradientOp.
// Stateless (no operator arguments are read); declarations only, with
// definitions in the per-device implementation files.
// NOTE(review): original line 259 (presumably `struct MaxPoolFunctor {`)
// is missing from this extraction — restore the struct header from the
// upstream file before this compiles.
258 template <class Context>
260  explicit MaxPoolFunctor(const OperatorBase& /* op */) {}
261 
// Max over all HxW spatial positions per (n, c).
262  template <typename T, StorageOrder kOrder>
263  bool GlobalPoolingForward(
264  int N,
265  int C,
266  int HxW,
267  const T* X,
268  T* Y,
269  Context* context) const;
270 
// Windowed max pooling with the given kernel/dilation/stride/pads.
271  template <typename T, StorageOrder kOrder>
272  bool Forward(
273  int N,
274  int C,
275  const std::vector<int>& X_dims,
276  const std::vector<int>& Y_dims,
277  const std::vector<int>& kernel,
278  const std::vector<int>& dilation,
279  const std::vector<int>& stride,
280  const std::vector<int>& pads,
281  const T* X,
282  T* Y,
283  Context* context) const;
284 
// Gradient of global max pooling (needs X and Y to locate the argmax).
285  template <typename T, StorageOrder kOrder>
286  bool GlobalPoolingBackward(
287  int N,
288  int C,
289  int HxW,
290  const T* dY,
291  const T* X,
292  const T* Y,
293  T* dX,
294  Context* context) const;
295 
// Gradient of windowed max pooling.
296  template <typename T, StorageOrder kOrder>
297  bool Backward(
298  int N,
299  int C,
300  const std::vector<int>& X_dims,
301  const std::vector<int>& Y_dims,
302  const std::vector<int>& kernel,
303  const std::vector<int>& dilation,
304  const std::vector<int>& stride,
305  const std::vector<int>& pads,
306  const T* dY,
307  const T* X,
308  const T* Y,
309  T* dX,
310  Context* context) const;
311 };
312 
313 } // namespace caffe2
314 
315 #endif // CAFFE2_OPERATORS_POOL_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64