Caffe2 - C++ API
A deep learning, cross platform ML framework
conv_gradient_op.cc
1 
17 #include "caffe2/operators/conv_op.h"
18 #include "caffe2/operators/conv_op_impl.h"
19 #include "caffe2/operators/conv_pool_op_base.h"
20 
21 namespace caffe2 {
22 
23 REGISTER_CPU_OPERATOR(ConvGradient, ConvGradientOp<float, CPUContext>);
24 OPERATOR_SCHEMA(ConvGradient).NumInputs(2, 3).NumOutputs(1, 3);
25 
26 REGISTER_CPU_OPERATOR(Conv1DGradient, ConvGradientOp<float, CPUContext>);
27 OPERATOR_SCHEMA(Conv1DGradient).NumInputs(2, 3).NumOutputs(1, 3);
28 
29 REGISTER_CPU_OPERATOR(Conv2DGradient, ConvGradientOp<float, CPUContext>);
30 OPERATOR_SCHEMA(Conv2DGradient).NumInputs(2, 3).NumOutputs(1, 3);
31 
32 REGISTER_CPU_OPERATOR(Conv3DGradient, ConvGradientOp<float, CPUContext>);
33 OPERATOR_SCHEMA(Conv3DGradient).NumInputs(2, 3).NumOutputs(1, 3);
34 
36  using GradientMakerBase::GradientMakerBase;
37  vector<OperatorDef> GetGradientDefs() override {
38  CAFFE_ENFORCE(def_.input_size() == 3 || def_.input_size() == 2);
39 
40  ArgumentHelper argsHelper(def_);
41 
42  auto compute_dX = !argsHelper.GetSingleArgument<bool>("no_gradient_to_input", 0);
43 
44  if (def_.input_size() == 3) {
45  if (compute_dX) {
46  return SingleGradientDef(
47  def_.type() + "Gradient",
48  "",
49  vector<string>{I(0), I(1), GO(0)},
50  vector<string>{GI(1), GI(2), GI(0)});
51  } else {
52  return SingleGradientDef(
53  def_.type() + "Gradient",
54  "",
55  vector<string>{I(0), I(1), GO(0)},
56  vector<string>{GI(1), GI(2)});
57  }
58  } else {
59  if (compute_dX) {
60  return SingleGradientDef(
61  def_.type() + "Gradient",
62  "",
63  vector<string>{I(0), I(1), GO(0)},
64  vector<string>{GI(1), GI(0)},
65  vector<Argument>{MakeArgument<int>("no_bias", 1)});
66  } else {
67  return SingleGradientDef(
68  def_.type() + "Gradient",
69  "",
70  vector<string>{I(0), I(1), GO(0)},
71  vector<string>{GI(1)},
72  vector<Argument>{MakeArgument<int>("no_bias", 1)});
73  }
74  }
75  }
76 };
77 REGISTER_GRADIENT(Conv, GetConvGradient);
78 REGISTER_GRADIENT(Conv1D, GetConvGradient);
79 REGISTER_GRADIENT(Conv2D, GetConvGradient);
80 REGISTER_GRADIENT(Conv3D, GetConvGradient);
81 
82 } // namespace caffe2
A helper class to index into arguments.
Definition: proto_utils.h:198
Copyright (c) 2016-present, Facebook, Inc.
static vector< OperatorDef > SingleGradientDef(const Args &...args)
A helper function that lets one create a single operator def, which is usually the case for many ...