Caffe2 - C++ API
A deep learning, cross platform ML framework
ps_roi_pool_op.h
1 
17 #ifndef PS_ROI_POOL_OP_H_
18 #define PS_ROI_POOL_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/logging.h"
22 #include "caffe2/core/operator.h"
23 #include "caffe2/utils/math.h"
24 
25 namespace caffe2 {
26 
27 template <typename T, class Context>
28 class PSRoIPoolOp final : public Operator<Context> {
29  public:
30  PSRoIPoolOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  spatial_scale_(this->template GetSingleArgument<float>(
33  "spatial_scale", 1.)),
34  group_size_(this->template GetSingleArgument<int>("group_size", 1)),
35  output_dim_(this->template GetSingleArgument<int>("output_dim", 1)) {
36  DCHECK_GT(spatial_scale_, 0);
37  DCHECK_GT(group_size_, 0);
38  pooled_height_ = group_size_;
39  pooled_width_ = group_size_;
40  }
41  USE_OPERATOR_CONTEXT_FUNCTIONS;
42 
43  bool RunOnDevice() override {
44  // No CPU implementation for now
45  CAFFE_NOT_IMPLEMENTED;
46  }
47 
48  protected:
49  float spatial_scale_;
50  int group_size_;
51  int output_dim_;
52  int pooled_height_;
53  int pooled_width_;
54  int channels_;
55  int height_;
56  int width_;
57  };
58 
59 template <typename T, class Context>
60 class PSRoIPoolGradientOp final : public Operator<Context> {
61  public:
62  PSRoIPoolGradientOp(const OperatorDef& def, Workspace* ws)
63  : Operator<Context>(def, ws),
64  spatial_scale_(this->template GetSingleArgument<float>(
65  "spatial_scale", 1.)),
66  group_size_(this->template GetSingleArgument<int>("group_size", 1)),
67  output_dim_(this->template GetSingleArgument<int>("output_dim", 1)) {
68  DCHECK_GT(spatial_scale_, 0);
69  DCHECK_GT(group_size_, 0);
70  pooled_height_ = group_size_;
71  pooled_width_ = group_size_;
72  }
73  USE_OPERATOR_CONTEXT_FUNCTIONS;
74 
75  bool RunOnDevice() override {
76  // No CPU implementation for now
77  CAFFE_NOT_IMPLEMENTED;
78  }
79 
80  protected:
81  float spatial_scale_;
82  int group_size_;
83  int output_dim_;
84  int pooled_height_;
85  int pooled_width_;
86  int channels_;
87  int height_;
88  int width_;
89 };
90 
91 } // namespace caffe2
92 
93 #endif // PS_ROI_POOL_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13