Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_transpose_gradient_op.cc
1 
17 #include "caffe2/operators/conv_transpose_op.h"
18 #include "caffe2/operators/conv_transpose_op_impl.h"
19 
20 namespace caffe2 {
21 
22 REGISTER_CPU_OPERATOR(
23  ConvTransposeGradient,
24  ConvTransposeGradientOp<float, CPUContext>);
25 
26 OPERATOR_SCHEMA(ConvTransposeGradient).NumInputs(3).NumOutputs(1, 3);
27 
29  using GradientMakerBase::GradientMakerBase;
30  vector<OperatorDef> GetGradientDefs() override {
31  auto compute_dX =
32  !ArgumentHelper::GetSingleArgument(def_, "no_gradient_to_input", false);
33 
34  CAFFE_ENFORCE(3 == def_.input_size() || 2 == def_.input_size());
35  if (def_.input_size() == 3 && compute_dX) {
36  return SingleGradientDef(
37  "ConvTransposeGradient",
38  "",
39  vector<string>{I(0), I(1), GO(0)},
40  vector<string>{GI(1), GI(2), GI(0)});
41  } else if (def_.input_size() == 3) {
42  return SingleGradientDef(
43  "ConvTransposeGradient",
44  "",
45  vector<string>{I(0), I(1), GO(0)},
46  vector<string>{GI(1), GI(2)});
47  } else if (compute_dX) {
48  return SingleGradientDef(
49  "ConvTransposeGradient",
50  "",
51  vector<string>{I(0), I(1), GO(0)},
52  vector<string>{GI(1), GI(0)},
53  vector<Argument>{MakeArgument<bool>("no_bias", true)});
54  } else {
55  return SingleGradientDef(
56  "ConvTransposeGradient",
57  "",
58  vector<string>{I(0), I(1), GO(0)},
59  vector<string>{GI(1)},
60  vector<Argument>{MakeArgument<bool>("no_bias", true)});
61  }
62  }
63 };
64 REGISTER_GRADIENT(ConvTranspose, GetConvTransposeGradient);
65 
66 } // namespace caffe2
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many ...