Caffe2 - C++ API
A deep-learning, cross-platform ML framework
pool_op.h
1 
17 #ifndef CAFFE2_OPERATORS_POOL_OP_H_
18 #define CAFFE2_OPERATORS_POOL_OP_H_
19 
20 #include "caffe2/core/common_omp.h"
21 #include "caffe2/core/context.h"
22 #include "caffe2/core/logging.h"
23 #include "caffe2/core/operator.h"
24 #include "caffe2/operators/conv_pool_op_base.h"
25 #include "caffe2/utils/math.h"
26 
27 namespace caffe2 {
28 
29 template <typename T, class Context, typename PoolType>
30 class PoolOp final : public ConvPoolOpBase<Context> {
31  public:
32  USE_CONV_POOL_BASE_FUNCTIONS(Context);
33  PoolOp(const OperatorDef& operator_def, Workspace* ws)
34  : ConvPoolOpBase<Context>(operator_def, ws) {
35  for (int i = 0; i < kernel_.size(); ++i) {
36  CAFFE_ENFORCE(
37  dilation_[i] == 1, "Pooling op does not support dilation right now.");
38  }
39  if (!global_pooling_) {
40  for (int i = 0; i < kernel_.size(); ++i) {
41  CAFFE_ENFORCE(
42  pads_[i] < kernel_[i] && pads_[i + kernel_.size()] < kernel_[i],
43  "Pad should be smaller than kernel.");
44  }
45  }
46  }
47  ~PoolOp() {}
48 
49  bool RunOnDeviceWithOrderNCHW() override;
50  bool RunOnDeviceWithOrderNHWC() override;
51 
52  // Input: X
53  // Output: Y
54 };
55 
56 template <typename T, class Context, class PoolType>
57 class PoolGradientOp final : public ConvPoolOpBase<Context> {
58  public:
59  USE_CONV_POOL_BASE_FUNCTIONS(Context);
60  PoolGradientOp(const OperatorDef& operator_def, Workspace* ws)
61  : ConvPoolOpBase<Context>(operator_def, ws) {}
62  ~PoolGradientOp() {}
63 
64  bool RunOnDeviceWithOrderNCHW() override;
65  bool RunOnDeviceWithOrderNHWC() override;
66 
67  // Input: X, Y, dY
68  // Output: dX
69 };
70 
71 } // namespace caffe2
72 
73 #endif // CAFFE2_OPERATORS_POOL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.