Caffe2 - C++ API
A deep learning, cross-platform ML framework
roi_pool_op.h
1 #ifndef ROI_POOL_OP_H_
2 #define ROI_POOL_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context>
12 class RoIPoolOp final : public Operator<Context> {
13  public:
14  template <class... Args>
15  explicit RoIPoolOp(Args&&... args)
16  : Operator<Context>(std::forward<Args>(args)...),
17  is_test_(
18  this->template GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
19  order_(StringToStorageOrder(
20  this->template GetSingleArgument<string>("order", "NCHW"))),
21  pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
22  pooled_width_(this->template GetSingleArgument<int>("pooled_w", 1)),
23  spatial_scale_(
24  this->template GetSingleArgument<float>("spatial_scale", 1.)) {
25  CAFFE_ENFORCE(
26  (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 2),
27  "Output size mismatch.");
28  CAFFE_ENFORCE_GT(spatial_scale_, 0);
29  CAFFE_ENFORCE_GT(pooled_height_, 0);
30  CAFFE_ENFORCE_GT(pooled_width_, 0);
31  CAFFE_ENFORCE_EQ(
32  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
33  }
34  USE_OPERATOR_CONTEXT_FUNCTIONS;
35 
36  bool RunOnDevice() override;
37 
38  protected:
39  bool is_test_;
40  StorageOrder order_;
41  int pooled_height_;
42  int pooled_width_;
43  float spatial_scale_;
44 };
45 
46 template <typename T, class Context>
47 class RoIPoolGradientOp final : public Operator<Context> {
48  public:
49  template <class... Args>
50  explicit RoIPoolGradientOp(Args&&... args)
51  : Operator<Context>(std::forward<Args>(args)...),
52  spatial_scale_(
53  this->template GetSingleArgument<float>("spatial_scale", 1.)),
54  pooled_height_(this->template GetSingleArgument<int>("pooled_h", 1)),
55  pooled_width_(this->template GetSingleArgument<int>("pooled_w", 1)),
56  order_(StringToStorageOrder(
57  this->template GetSingleArgument<string>("order", "NCHW"))) {
58  CAFFE_ENFORCE_GT(spatial_scale_, 0);
59  CAFFE_ENFORCE_GT(pooled_height_, 0);
60  CAFFE_ENFORCE_GT(pooled_width_, 0);
61  CAFFE_ENFORCE_EQ(
62  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
63  }
64  USE_OPERATOR_CONTEXT_FUNCTIONS;
65 
66  bool RunOnDevice() override {
67  CAFFE_NOT_IMPLEMENTED;
68  }
69 
70  protected:
71  float spatial_scale_;
72  int pooled_height_;
73  int pooled_width_;
74  StorageOrder order_;
75 };
76 
77 } // namespace caffe2
78 
79 #endif // ROI_POOL_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13