1 #ifndef CAFFE2_OPERATORS_DEFORM_CONV_OP_H_ 2 #define CAFFE2_OPERATORS_DEFORM_CONV_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/operator.h" 6 #include "caffe2/operators/conv_op_shared.h" 7 #include "caffe2/operators/conv_pool_op_base.h" 9 C10_DECLARE_bool(caffe2_force_shared_col_buffer);
13 template <
typename T,
class Context>
16 USE_CONV_POOL_BASE_FUNCTIONS(Context);
20 this->
template GetSingleArgument<int>(
"deformable_group", 1)) {}
24 void DeformableIm2col(
30 void DeformableCol2im(
36 void DeformableCol2imCoord(
45 int deformable_group_;
47 #define USE_DEFORMABLE_CONV_BASE_FUNCTIONS(T, Context) \ 48 USE_CONV_POOL_BASE_FUNCTIONS(Context); \ 49 using DeformConvOpBase<T, Context>::deformable_group_; \ 50 using DeformConvOpBase<T, Context>::DeformableIm2col; \ 51 using DeformConvOpBase<T, Context>::DeformableCol2im; \ 52 using DeformConvOpBase<T, Context>::DeformableCol2imCoord 55 template <
typename T,
class Context>
58 USE_DEFORMABLE_CONV_BASE_FUNCTIONS(
T, Context);
64 if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
65 createSharedBuffer<Context>(ws_);
70 bool RunOnDeviceWithOrderNCHW()
override;
73 Tensor col_buffer_{Context::GetDeviceType()};
75 Tensor img_shape_device_{Context::GetDeviceType()};
76 Tensor col_buffer_shape_device_{Context::GetDeviceType()};
79 INPUT_TAGS(INPUT, OFFSET, FILTER, BIAS);
82 template <
typename T,
class Context>
85 USE_DEFORMABLE_CONV_BASE_FUNCTIONS(
T, Context);
89 no_bias_(this->
template GetSingleArgument<int>(
"no_bias", 0)) {
91 !(no_bias_ && OutputSize() == 4),
92 "If bias is not present, you should not have 4 grad output.");
96 bool RunOnDeviceWithOrderNCHW()
override;
101 Tensor img_shape_device_{Context::GetDeviceType()};
102 Tensor col_buffer_shape_device_{Context::GetDeviceType()};
106 INPUT_TAGS(INPUT, OFFSET, FILTER, OUTPUT_GRAD);
107 OUTPUT_TAGS(OFFSET_GRAD, FILTER_GRAD, BIAS_OR_INPUT_GRAD, INPUT_GRAD);
112 #endif // CAFFE2_OPERATORS_DEFORM_CONV_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...