1 #ifndef CAFFE2_OPERATORS_POOL_OP_H_ 2 #define CAFFE2_OPERATORS_POOL_OP_H_ 6 #include "caffe2/core/common_omp.h" 7 #include "caffe2/core/context.h" 8 #include "caffe2/core/logging.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/operators/conv_pool_op_base.h" 14 template <
typename T,
class Context,
class Functor>
17 USE_CONV_POOL_BASE_FUNCTIONS(Context);
19 template <
class... Args>
20 explicit PoolOp(Args&&... args)
22 const int kernel_size = kernel_.size();
23 for (
int i = 0; i < kernel_size; ++i) {
25 dilation_[i], 1,
"Pooling op does not support dilation right now.");
27 if (!global_pooling_) {
28 for (
int i = 0; i < kernel_size; ++i) {
30 pads_[i] < kernel_[i] && pads_[i + kernel_size] < kernel_[i],
31 "Pad should be smaller than kernel.");
38 bool RunOnDeviceWithOrderNCHW()
override {
39 const auto& X =
Input(0);
41 const int N = X.dim32(0);
42 const int C = X.dim32(1);
44 const T* X_data = X.template data<T>();
45 T* Y_data = Y->template mutable_data<T>();
49 if (global_pooling_) {
50 const int HxW = X.numel() / (N * C);
51 return functor_.template GlobalPoolingForward<T, StorageOrder::NCHW>(
52 N, C, HxW, X_data, Y_data, &context_);
54 const std::vector<int> X_HW_dims = GetDims(X);
55 const std::vector<int> Y_HW_dims = GetDims(*Y);
56 return functor_.template Forward<T, StorageOrder::NCHW>(
66 Y->template mutable_data<T>(),
70 bool RunOnDeviceWithOrderNHWC()
override {
71 const auto& X =
Input(0);
73 const int ndim = X.dim();
74 const int N = X.dim32(0);
75 const int C = X.dim32(ndim - 1);
77 const T* X_data = X.template data<T>();
78 T* Y_data = Y->template mutable_data<T>();
82 if (global_pooling_) {
83 const int HxW = X.numel() / (N * C);
84 return functor_.template GlobalPoolingForward<T, StorageOrder::NHWC>(
85 N, C, HxW, X_data, Y_data, &context_);
87 const std::vector<int> X_HW_dims = GetDims(X);
88 const std::vector<int> Y_HW_dims = GetDims(*Y);
89 return functor_.template Forward<T, StorageOrder::NHWC>(
99 Y->template mutable_data<T>(),
104 const Functor functor_;
107 template <
typename T,
class Context,
class Functor>
110 USE_CONV_POOL_BASE_FUNCTIONS(Context);
111 template <
class... Args>
115 ~PoolGradientOp() =
default;
117 bool RunOnDeviceWithOrderNCHW()
override {
118 const auto& X =
Input(0);
119 const auto& Y =
Input(1);
120 const auto& dY =
Input(2);
121 auto* dX = Output(0, X.sizes(), at::dtype<T>());
122 const int N = X.dim32(0);
123 const int C = X.dim32(1);
124 const std::vector<int> X_HW_dims = GetDims(X);
125 const std::vector<int> Y_HW_dims = GetDims(Y);
127 const T* dY_data = dY.template data<T>();
128 const T* X_data = X.template data<T>();
129 const T* Y_data = Y.template data<T>();
130 T* dX_data = dX->template mutable_data<T>();
134 if (global_pooling_) {
135 const int HxW = X.numel() / (N * C);
136 return functor_.template GlobalPoolingBackward<T, StorageOrder::NCHW>(
137 N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
139 return functor_.template Backward<T, StorageOrder::NCHW>(
155 bool RunOnDeviceWithOrderNHWC()
override {
156 const auto& X =
Input(0);
157 const auto& Y =
Input(1);
158 const auto& dY =
Input(2);
159 auto* dX = Output(0, X.sizes(), at::dtype<T>());
160 const int ndim = X.dim();
161 const int N = X.dim32(0);
162 const int C = X.dim32(ndim - 1);
163 const std::vector<int> X_HW_dims = GetDims(X);
164 const std::vector<int> Y_HW_dims = GetDims(Y);
166 const T* dY_data = dY.template data<T>();
167 const T* X_data = X.template data<T>();
168 const T* Y_data = Y.template data<T>();
169 T* dX_data = dX->template mutable_data<T>();
173 if (global_pooling_) {
174 const int HxW = X.numel() / (N * C);
175 return functor_.template GlobalPoolingBackward<T, StorageOrder::NHWC>(
176 N, C, HxW, dY_data, X_data, Y_data, dX_data, &context_);
178 return functor_.template Backward<T, StorageOrder::NHWC>(
195 const Functor functor_;
198 template <
class Context>
202 op.template GetSingleArgument<bool>(
"count_include_pad",
false)) {}
204 template <
typename T, StorageOrder kOrder>
205 bool GlobalPoolingForward(
211 Context* context)
const;
213 template <
typename T, StorageOrder kOrder>
217 const std::vector<int>& X_dims,
218 const std::vector<int>& Y_dims,
219 const std::vector<int>& kernel,
220 const std::vector<int>& dilation,
221 const std::vector<int>& stride,
222 const std::vector<int>& pads,
225 Context* context)
const;
227 template <
typename T, StorageOrder kOrder>
228 bool GlobalPoolingBackward(
236 Context* context)
const;
238 template <
typename T, StorageOrder kOrder>
242 const std::vector<int>& X_dims,
243 const std::vector<int>& Y_dims,
244 const std::vector<int>& kernel,
245 const std::vector<int>& dilation,
246 const std::vector<int>& stride,
247 const std::vector<int>& pads,
252 Context* context)
const;
254 const bool count_include_pad;
255 Tensor ones{Context::GetDeviceType()};
258 template <
class Context>
262 template <
typename T, StorageOrder kOrder>
263 bool GlobalPoolingForward(
269 Context* context)
const;
271 template <
typename T, StorageOrder kOrder>
275 const std::vector<int>& X_dims,
276 const std::vector<int>& Y_dims,
277 const std::vector<int>& kernel,
278 const std::vector<int>& dilation,
279 const std::vector<int>& stride,
280 const std::vector<int>& pads,
283 Context* context)
const;
285 template <
typename T, StorageOrder kOrder>
286 bool GlobalPoolingBackward(
294 Context* context)
const;
296 template <
typename T, StorageOrder kOrder>
300 const std::vector<int>& X_dims,
301 const std::vector<int>& Y_dims,
302 const std::vector<int>& kernel,
303 const std::vector<int>& dilation,
304 const std::vector<int>& stride,
305 const std::vector<int>& pads,
310 Context* context)
const;
315 #endif // CAFFE2_OPERATORS_POOL_OP_H_
// NOTE(review): the lines below are tooltip/documentation residue leaked in
// by the source browser this file was scraped from; they are not part of the
// header. Preserved here as comments:
//   const Tensor& Input(int idx, DeviceType type = Context::GetDeviceType())
//     "Retrieve a non-owning reference to the input at position 'idx' for
//      this operator. ..."
//   "A global dictionary that holds information about what Caffe2 modules
//    have been loaded in the current ..."