Caffe2 - C++ API
A deep learning, cross-platform ML framework
deform_conv_op.h
1 #ifndef CAFFE2_OPERATORS_DEFORM_CONV_OP_H_
2 #define CAFFE2_OPERATORS_DEFORM_CONV_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/operators/conv_op_shared.h"
7 #include "caffe2/operators/conv_pool_op_base.h"
8 
9 C10_DECLARE_bool(caffe2_force_shared_col_buffer);
10 
11 namespace caffe2 {
12 
13 template <typename T, class Context>
14 class DeformConvOpBase : public ConvPoolOpBase<Context> {
15  public:
16  USE_CONV_POOL_BASE_FUNCTIONS(Context);
17  explicit DeformConvOpBase(const OperatorDef& operator_def, Workspace* ws)
18  : ConvPoolOpBase<Context>(operator_def, ws),
19  deformable_group_(
20  this->template GetSingleArgument<int>("deformable_group", 1)) {}
21  ~DeformConvOpBase() {}
22 
23  protected:
24  void DeformableIm2col(
25  const T* data_im,
26  const T* data_offset,
27  at::IntArrayRef im_shape,
28  at::IntArrayRef col_shape,
29  T* data_col);
30  void DeformableCol2im(
31  const T* data_col,
32  const T* data_offset,
33  at::IntArrayRef im_shape,
34  at::IntArrayRef col_shape,
35  T* grad_im);
36  void DeformableCol2imCoord(
37  const T* data_col,
38  const T* data_im,
39  const T* data_offset,
40  at::IntArrayRef im_shape,
41  at::IntArrayRef col_shape,
42  T* grad_offset);
43 
44  protected:
45  int deformable_group_;
46 
47 #define USE_DEFORMABLE_CONV_BASE_FUNCTIONS(T, Context) \
48  USE_CONV_POOL_BASE_FUNCTIONS(Context); \
49  using DeformConvOpBase<T, Context>::deformable_group_; \
50  using DeformConvOpBase<T, Context>::DeformableIm2col; \
51  using DeformConvOpBase<T, Context>::DeformableCol2im; \
52  using DeformConvOpBase<T, Context>::DeformableCol2imCoord
53 };
54 
55 template <typename T, class Context>
56 class DeformConvOp final : public DeformConvOpBase<T, Context> {
57  public:
58  USE_DEFORMABLE_CONV_BASE_FUNCTIONS(T, Context);
59 
60  explicit DeformConvOp(const OperatorDef& operator_def, Workspace* ws)
61  : DeformConvOpBase<T, Context>(operator_def, ws) {
62  // Create shared buffer mutex in the constructor
63  // to avoid race-condition in DAGNet.
64  if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
65  createSharedBuffer<Context>(ws_);
66  }
67  }
68  ~DeformConvOp() {}
69 
70  bool RunOnDeviceWithOrderNCHW() override;
71 
72  private:
73  Tensor col_buffer_{Context::GetDeviceType()};
74  Tensor bias_multiplier_;
75  Tensor img_shape_device_{Context::GetDeviceType()};
76  Tensor col_buffer_shape_device_{Context::GetDeviceType()};
77  // Input: X, o, W, b
78  // Output: Y
79  INPUT_TAGS(INPUT, OFFSET, FILTER, BIAS);
80 };
81 
82 template <typename T, class Context>
83 class DeformConvGradientOp final : public DeformConvOpBase<T, Context> {
84  public:
85  USE_DEFORMABLE_CONV_BASE_FUNCTIONS(T, Context);
86 
87  explicit DeformConvGradientOp(const OperatorDef& operator_def, Workspace* ws)
88  : DeformConvOpBase<T, Context>(operator_def, ws),
89  no_bias_(this->template GetSingleArgument<int>("no_bias", 0)) {
90  CAFFE_ENFORCE(
91  !(no_bias_ && OutputSize() == 4),
92  "If bias is not present, you should not have 4 grad output.");
93  }
95 
96  bool RunOnDeviceWithOrderNCHW() override;
97 
98  private:
99  Tensor col_buffer_;
100  Tensor bias_multiplier_;
101  Tensor img_shape_device_{Context::GetDeviceType()};
102  Tensor col_buffer_shape_device_{Context::GetDeviceType()};
103  bool no_bias_;
104  // input: X, W, dY
105  // output: dO, dW, db, and optionally dX
106  INPUT_TAGS(INPUT, OFFSET, FILTER, OUTPUT_GRAD);
107  OUTPUT_TAGS(OFFSET_GRAD, FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
108 };
109 
110 } // namespace caffe2
111 
112 #endif // CAFFE2_OPERATORS_DEFORM_CONV_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13