Caffe2 - C++ API
A deep learning, cross platform ML framework
roi_pool_op.h
1 
17 #ifndef ROI_POOL_OP_H_
18 #define ROI_POOL_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class RoIPoolOp final : public Operator<Context> {
29  public:
30  RoIPoolOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
33  order_(StringToStorageOrder(
34  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
35  pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
36  pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
37  spatial_scale_(
38  OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)) {
39  CAFFE_ENFORCE(
40  (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 2),
41  "Output size mismatch.");
42  CAFFE_ENFORCE_GT(spatial_scale_, 0);
43  CAFFE_ENFORCE_GT(pooled_height_, 0);
44  CAFFE_ENFORCE_GT(pooled_width_, 0);
45  CAFFE_ENFORCE_EQ(
46  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
47  }
48  USE_OPERATOR_CONTEXT_FUNCTIONS;
49 
50  bool RunOnDevice() override;
51 
52  protected:
53  bool is_test_;
54  StorageOrder order_;
55  int pooled_height_;
56  int pooled_width_;
57  float spatial_scale_;
58 };
59 
60 template <typename T, class Context>
61 class RoIPoolGradientOp final : public Operator<Context> {
62  public:
63  RoIPoolGradientOp(const OperatorDef& def, Workspace* ws)
64  : Operator<Context>(def, ws),
65  spatial_scale_(
66  OperatorBase::GetSingleArgument<float>("spatial_scale", 1.)),
67  pooled_height_(OperatorBase::GetSingleArgument<int>("pooled_h", 1)),
68  pooled_width_(OperatorBase::GetSingleArgument<int>("pooled_w", 1)),
69  order_(StringToStorageOrder(
70  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
71  CAFFE_ENFORCE_GT(spatial_scale_, 0);
72  CAFFE_ENFORCE_GT(pooled_height_, 0);
73  CAFFE_ENFORCE_GT(pooled_width_, 0);
74  CAFFE_ENFORCE_EQ(
75  order_, StorageOrder::NCHW, "Only NCHW order is supported right now.");
76  }
77  USE_OPERATOR_CONTEXT_FUNCTIONS;
78 
79  bool RunOnDevice() override {
80  CAFFE_NOT_IMPLEMENTED;
81  }
82 
83  protected:
84  float spatial_scale_;
85  int pooled_height_;
86  int pooled_width_;
87  StorageOrder order_;
88 };
89 
90 } // namespace caffe2
91 
92 #endif // ROI_POOL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.