Caffe2 - C++ API
A deep learning, cross-platform ML framework
deform_conv_gradient_op.cc
1 
17 #include "caffe2/operators/conv_pool_op_base.h"
18 #include "caffe2/operators/deform_conv_op.h"
19 #include "caffe2/operators/deform_conv_op_impl.h"
20 
21 namespace caffe2 {
22 
23 OPERATOR_SCHEMA(DeformConvGradient).NumInputs(4, 4).NumOutputs(2, 4);
24 
25 namespace {
26 
27 class GetDeformConvGradient : public GradientMakerBase {
28  using GradientMakerBase::GradientMakerBase;
29  vector<OperatorDef> GetGradientDefs() override {
30  CAFFE_ENFORCE(def_.input_size() == 3 || def_.input_size() == 4);
31 
32  ArgumentHelper argsHelper(def_);
33 
34  auto compute_dX =
35  !argsHelper.GetSingleArgument<bool>("no_gradient_to_input", 0);
36 
37  if (def_.input_size() == 4) {
38  if (compute_dX) {
39  return SingleGradientDef(
40  "DeformConvGradient",
41  "",
42  vector<string>{I(0), I(1), I(2), GO(0)},
43  vector<string>{GI(1), GI(2), GI(3), GI(0)});
44  } else {
45  return SingleGradientDef(
46  "DeformConvGradient",
47  "",
48  vector<string>{I(0), I(1), I(2), GO(0)},
49  vector<string>{GI(1), GI(2), GI(3)});
50  }
51  } else {
52  if (compute_dX) {
53  return SingleGradientDef(
54  "DeformConvGradient",
55  "",
56  vector<string>{I(0), I(1), I(2), GO(0)},
57  vector<string>{GI(1), GI(2), GI(0)},
58  vector<Argument>{MakeArgument<int>("no_bias", 1)});
59  } else {
60  return SingleGradientDef(
61  "DeformConvGradient",
62  "",
63  vector<string>{I(0), I(1), I(2), GO(0)},
64  vector<string>{GI(1), GI(2)},
65  vector<Argument>{MakeArgument<int>("no_bias", 1)});
66  }
67  }
68  }
69 };
// Associate the DeformConv operator with GetDeformConvGradient as its gradient maker.
70 REGISTER_GRADIENT(DeformConv, GetDeformConvGradient);
71 
72 } // namespace
73 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.