Caffe2 - C++ API
A deep learning, cross-platform ML framework
deform_conv_op.h
1 
17 #ifndef CAFFE2_OPERATORS_DEFORM_CONV_OP_H_
18 #define CAFFE2_OPERATORS_DEFORM_CONV_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/operators/conv_op_shared.h"
23 #include "caffe2/operators/conv_pool_op_base.h"
24 
25 CAFFE2_DECLARE_bool(caffe2_force_shared_col_buffer);
26 
27 namespace caffe2 {
28 
29 template <typename T, class Context>
30 class DeformConvOpBase : public ConvPoolOpBase<Context> {
31  public:
32  USE_CONV_POOL_BASE_FUNCTIONS(Context);
33  DeformConvOpBase(const OperatorDef& operator_def, Workspace* ws)
34  : ConvPoolOpBase<Context>(operator_def, ws),
35  deformable_group_(
36  OperatorBase::GetSingleArgument<int>("deformable_group", 1)) {}
37  ~DeformConvOpBase() {}
38 
39  protected:
40  void DeformableIm2col(
41  const T* data_im,
42  const T* data_offset,
43  const std::vector<TIndex>& im_shape,
44  const std::vector<TIndex>& col_shape,
45  T* data_col);
46  void DeformableCol2im(
47  const T* data_col,
48  const T* data_offset,
49  const std::vector<TIndex>& im_shape,
50  const std::vector<TIndex>& col_shape,
51  T* grad_im);
52  void DeformableCol2imCoord(
53  const T* data_col,
54  const T* data_im,
55  const T* data_offset,
56  const std::vector<TIndex>& im_shape,
57  const std::vector<TIndex>& col_shape,
58  T* grad_offset);
59 
60  protected:
61  int deformable_group_;
62 
63 #define USE_DEFORMABLE_CONV_BASE_FUNCTIONS(T, Context) \
64  USE_CONV_POOL_BASE_FUNCTIONS(Context); \
65  using DeformConvOpBase<T, Context>::deformable_group_; \
66  using DeformConvOpBase<T, Context>::DeformableIm2col; \
67  using DeformConvOpBase<T, Context>::DeformableCol2im; \
68  using DeformConvOpBase<T, Context>::DeformableCol2imCoord
69 };
70 
71 template <typename T, class Context>
72 class DeformConvOp final : public DeformConvOpBase<T, Context> {
73  public:
74  USE_DEFORMABLE_CONV_BASE_FUNCTIONS(T, Context);
75 
76  DeformConvOp(const OperatorDef& operator_def, Workspace* ws)
77  : DeformConvOpBase<T, Context>(operator_def, ws) {
78  // Create shared buffer mutex in the constructor
79  // to avoid race-condition in DAGNet.
80  if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
81  createSharedBuffer<Context>(ws_);
82  }
83  }
84  ~DeformConvOp() {}
85 
86  bool RunOnDeviceWithOrderNCHW() override;
87 
88  private:
89  Tensor<Context> col_buffer_;
90  Tensor<Context> bias_multiplier_;
91  Tensor<Context> img_shape_device_;
92  Tensor<Context> col_buffer_shape_device_;
93  // Input: X, o, W, b
94  // Output: Y
95  INPUT_TAGS(INPUT, OFFSET, FILTER, BIAS);
96 };
97 
98 template <typename T, class Context>
99 class DeformConvGradientOp final : public DeformConvOpBase<T, Context> {
100  public:
101  USE_DEFORMABLE_CONV_BASE_FUNCTIONS(T, Context);
102 
103  DeformConvGradientOp(const OperatorDef& operator_def, Workspace* ws)
104  : DeformConvOpBase<T, Context>(operator_def, ws),
105  no_bias_(OperatorBase::GetSingleArgument<int>("no_bias", 0)) {
106  CAFFE_ENFORCE(
107  !(no_bias_ && OutputSize() == 4),
108  "If bias is not present, you should not have 4 grad output.");
109  }
111 
112  bool RunOnDeviceWithOrderNCHW() override;
113 
114  private:
115  Tensor<Context> col_buffer_;
116  Tensor<Context> bias_multiplier_;
117  Tensor<Context> img_shape_device_;
118  Tensor<Context> col_buffer_shape_device_;
119  bool no_bias_;
120  // input: X, W, dY
121  // output: dO, dW, db, and optionally dX
122  INPUT_TAGS(INPUT, OFFSET, FILTER, OUTPUT_GRAD);
123  OUTPUT_TAGS(OFFSET_GRAD, FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
124 };
125 
126 } // namespace caffe2
127 
128 #endif // CAFFE2_OPERATORS_DEFORM_CONV_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous memory with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.