Caffe2 - C++ API
A deep-learning, cross-platform ML framework
deform_conv_gradient_op.cc
1 #include "caffe2/operators/conv_pool_op_base.h"
2 #include "caffe2/operators/deform_conv_op.h"
3 #include "caffe2/operators/deform_conv_op_impl.h"
4 
5 namespace caffe2 {
6 
7 OPERATOR_SCHEMA(DeformConvGradient).NumInputs(4, 4).NumOutputs(2, 4);
8 
9 namespace {
10 
11 class GetDeformConvGradient : public GradientMakerBase {
12  using GradientMakerBase::GradientMakerBase;
13  vector<OperatorDef> GetGradientDefs() override {
14  CAFFE_ENFORCE(def_.input_size() == 3 || def_.input_size() == 4);
15 
16  ArgumentHelper argsHelper(def_);
17 
18  auto compute_dX =
19  !argsHelper.GetSingleArgument<bool>("no_gradient_to_input", 0);
20 
21  if (def_.input_size() == 4) {
22  if (compute_dX) {
23  return SingleGradientDef(
24  "DeformConvGradient",
25  "",
26  vector<string>{I(0), I(1), I(2), GO(0)},
27  vector<string>{GI(1), GI(2), GI(3), GI(0)});
28  } else {
29  return SingleGradientDef(
30  "DeformConvGradient",
31  "",
32  vector<string>{I(0), I(1), I(2), GO(0)},
33  vector<string>{GI(1), GI(2), GI(3)});
34  }
35  } else {
36  if (compute_dX) {
37  return SingleGradientDef(
38  "DeformConvGradient",
39  "",
40  vector<string>{I(0), I(1), I(2), GO(0)},
41  vector<string>{GI(1), GI(2), GI(0)},
42  vector<Argument>{MakeArgument<int>("no_bias", 1)});
43  } else {
44  return SingleGradientDef(
45  "DeformConvGradient",
46  "",
47  vector<string>{I(0), I(1), I(2), GO(0)},
48  vector<string>{GI(1), GI(2)},
49  vector<Argument>{MakeArgument<int>("no_bias", 1)});
50  }
51  }
52  }
53 };
54 REGISTER_GRADIENT(DeformConv, GetDeformConvGradient);
55 
56 } // namespace
57 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13