Caffe2 - C++ API
A deep learning, cross platform ML framework
conv_transpose_gradient_op.cc
1 #include "caffe2/operators/conv_transpose_op.h"
2 #include "caffe2/operators/conv_transpose_op_impl.h"
3 
4 namespace caffe2 {
5 
// Register ConvTransposeGradientOp<float, CPUContext> under the operator name
// ConvTransposeGradient.
6 REGISTER_CPU_OPERATOR(
7  ConvTransposeGradient,
8  ConvTransposeGradientOp<float, CPUContext>);
9 
// Schema for ConvTransposeGradient: NumInputs(3), NumOutputs(1, 3).
// NOTE(review): (1, 3) presumably is a min/max range; the gradient maker in
// this file requests between one and three outputs -- confirm against
// OpSchema::NumOutputs.
10 OPERATOR_SCHEMA(ConvTransposeGradient).NumInputs(3).NumOutputs(1, 3);
11 
13  using GradientMakerBase::GradientMakerBase;
14  vector<OperatorDef> GetGradientDefs() override {
15  auto compute_dX =
16  !ArgumentHelper::GetSingleArgument(def_, "no_gradient_to_input", false);
17 
18  CAFFE_ENFORCE(3 == def_.input_size() || 2 == def_.input_size());
19  if (def_.input_size() == 3 && compute_dX) {
20  return SingleGradientDef(
21  "ConvTransposeGradient",
22  "",
23  vector<string>{I(0), I(1), GO(0)},
24  vector<string>{GI(1), GI(2), GI(0)});
25  } else if (def_.input_size() == 3) {
26  return SingleGradientDef(
27  "ConvTransposeGradient",
28  "",
29  vector<string>{I(0), I(1), GO(0)},
30  vector<string>{GI(1), GI(2)});
31  } else if (compute_dX) {
32  return SingleGradientDef(
33  "ConvTransposeGradient",
34  "",
35  vector<string>{I(0), I(1), GO(0)},
36  vector<string>{GI(1), GI(0)},
37  vector<Argument>{MakeArgument<bool>("no_bias", true)});
38  } else {
39  return SingleGradientDef(
40  "ConvTransposeGradient",
41  "",
42  vector<string>{I(0), I(1), GO(0)},
43  vector<string>{GI(1)},
44  vector<Argument>{MakeArgument<bool>("no_bias", true)});
45  }
46  }
47 };
// Register GetConvTransposeGradient as the gradient maker for the
// ConvTranspose operator.
48 REGISTER_GRADIENT(ConvTranspose, GetConvTransposeGradient);
49 
50 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that creates one single operator def, which is usually the case for many simple operators.