Caffe2 - C++ API
A deep learning, cross platform ML framework
operator_gradient.h
1 #ifndef CAFFE2_CORE_OPERATOR_GRADIENT_H_
2 #define CAFFE2_CORE_OPERATOR_GRADIENT_H_
3 
4 #include "c10/util/Registry.h"
5 #include "caffe2/core/operator_schema.h"
6 #include "caffe2/proto/caffe2_pb.h"
7 #include "caffe2/utils/proto_utils.h"
8 
9 namespace caffe2 {
10 
11 /* @brief A struct that abstracts on top of dense and sparse blobs.
12  *
13  * For a dense blob, its gradient name should be written into dense_, and for
14  * a sparse blob, its gradient name should be written into indice_ for
15  * the sparse indices and value_ for the values.
16  */
17 struct CAFFE2_API GradientWrapper {
18  string dense_;
19  string indices_;
20  string values_;
21 
22  inline bool IsDense() const {
23  return (dense_.size() != 0);
24  }
25  inline bool IsSparse() const {
26  return (indices_.size() != 0 || values_.size() != 0);
27  }
28  inline bool IsEmpty() const {
29  return (!IsDense() && !IsSparse());
30  }
31 };
32 
36 struct CAFFE2_API GradientOpsMeta {
37  vector<OperatorDef> ops_;
38  vector<GradientWrapper> g_input_;
39 
40  GradientOpsMeta() {}
42  const vector<OperatorDef>& ops,
43  const vector<GradientWrapper>& v)
44  : ops_(ops), g_input_(v) {}
45 };
46 
47 class CAFFE2_API GradientMakerBase {
48  public:
50  const OperatorDef& def,
51  const vector<GradientWrapper>& g_output)
52  : def_(def), g_output_(g_output), g_input_(def.input_size()){};
53  virtual ~GradientMakerBase() {}
54  virtual bool CopyDeviceOption() const {
55  return true;
56  }
57  virtual bool CopyEngine() const {
58  return true;
59  }
60  virtual bool CopyArguments() const {
61  return true;
62  }
63 
64  virtual void VerifyOp() const {
65  auto* schema = OpSchemaRegistry::Schema(def_.type());
66  if (schema) {
67  CAFFE_ENFORCE(
68  schema->Verify(def_),
69  "(GradientMaker) Operator def did not pass schema checking: ",
70  ProtoDebugString(def_));
71  }
72  }
73 
85  virtual GradientOpsMeta Get() {
86  VerifyOp();
87  vector<OperatorDef> new_defs = GetGradientDefs();
88  for (auto& opdef : new_defs) {
89  opdef.set_is_gradient_op(true);
90  }
91  return GradientOpsMeta(new_defs, g_input_);
92  };
93 
94  const OperatorDef& Def() const {
95  return def_;
96  }
97 
98  protected:
99  virtual vector<OperatorDef> GetGradientDefs() {
100  CAFFE_NOT_IMPLEMENTED;
101  }
102 
103  // Helper functions to return names for the gradient computation.
104  // I(idx), O(idx): return the input and output names.
105  // GO(idx): return the name of the gradient for output idx.
106  // GI(idx), GI_I(idx), GI_V(idx): return the name of the gradient for
107  // input idx, and also registers that name into the gradient
108  // registry to be returned.
109  string I(const int i) {
110  CAFFE_ENFORCE((i >= 0) && (i < def_.input().size()));
111  return def_.input(i);
112  }
113  string O(const int i) {
114  CAFFE_ENFORCE((i >= 0) && (i < def_.output().size()));
115  return def_.output(i);
116  }
117  string GI(const int i) {
118  CAFFE_ENFORCE(
119  !g_input_.at(i).IsSparse(),
120  "Input ",
121  def_.input(i),
122  " already set to sparse.");
123  g_input_.at(i).dense_ = GradientName(def_.input(i));
124  return GradientName(def_.input(i));
125  }
126  string GI_I(const int i) {
127  CAFFE_ENFORCE(
128  !g_input_.at(i).IsDense(),
129  "Input ",
130  def_.input(i),
131  " already set to dense.");
132  g_input_.at(i).indices_ = GradientSliceIndices(def_.input(i));
133  return GradientSliceIndices(def_.input(i));
134  }
135  string GI_V(const int i) {
136  CAFFE_ENFORCE(
137  !g_input_.at(i).IsDense(),
138  "Input ",
139  def_.input(i),
140  " already set to dense.");
141  g_input_.at(i).values_ = GradientSliceValues(def_.input(i));
142  return GradientSliceValues(def_.input(i));
143  }
144  string GO(const int i) {
145  CAFFE_ENFORCE(
146  g_output_.at(i).IsDense(),
147  "Gradient of output ",
148  def_.output(i),
149  (g_output_.at(i).IsSparse() ? " is sparse (expected dense)."
150  : " is not provided!"));
151  return g_output_.at(i).dense_;
152  }
153  string GO_I(const int i) {
154  CAFFE_ENFORCE(
155  g_output_.at(i).IsSparse(),
156  "Gradient of output ",
157  def_.output(i),
158  (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
159  : " is not provided!"));
160  return g_output_.at(i).indices_;
161  }
162  string GO_V(const int i) {
163  CAFFE_ENFORCE(
164  g_output_.at(i).IsSparse(),
165  "Gradient of output ",
166  def_.output(i),
167  (g_output_.at(i).IsDense() ? " is dense (expected sparse)."
168  : " is not provided!"));
169  return g_output_.at(i).values_;
170  }
171  const GradientWrapper& GradOut(int i) {
172  return g_output_.at(i);
173  }
174 
175  // Function to add a gradient pair to map.
176  void SetDense(const int i, const string& name) {
177  CAFFE_ENFORCE(
178  !g_input_.at(i).IsSparse(),
179  "Input ",
180  def_.input(i),
181  " already set to sparse.");
182  g_input_.at(i).dense_ = name;
183  }
184  void SetSparse(const int i, const string& indices, const string& values) {
185  CAFFE_ENFORCE(
186  !g_input_.at(i).IsDense(),
187  "Input ",
188  def_.input(i),
189  " already set to dense.");
190  g_input_.at(i).indices_ = indices;
191  g_input_.at(i).values_ = values;
192  }
193 
198  template <class... Args>
199  inline static vector<OperatorDef> SingleGradientDef(const Args&... args) {
200  return vector<OperatorDef>{CreateOperatorDef(args...)};
201  }
202 
203  public:
207  static CaffeMap<string, string> MatchGradsToParams(const OperatorDef& op) {
208  // NOTE: how to go beyond string-matching?
209  CaffeMap<string, string> m;
210  for (auto& out : op.output()) {
211  if (IsGradientBlob(out)) {
212  m[out] = out.substr(0, out.length() - 5);
213  }
214  }
215  return m;
216  }
217 
218  private:
219  // Utility functions for gradient name computation. We don't expose them
220  // in order to discourage the use of such names explicitly.
221  static string GradientName(const string& name) {
222  return name + "_grad";
223  }
224 
225  static bool IsGradientBlob(const string& name) {
226  return name.length() > 5 && name.find("_grad") == name.length() - 5;
227  }
228 
229  static string GradientNameToParam(const string& name) {
230  CHECK(IsGradientBlob(name));
231  return name.substr(0, name.length() - 5);
232  }
233 
234  static string GradientSliceIndices(const string& name) {
235  return name + "_grad_indices";
236  }
237 
238  static string GradientSliceValues(const string& name) {
239  return name + "_grad_values";
240  }
241 
242  protected:
243  // We make the member variables protected in case someone wants to write
244  // a fully custom Get() function.
245  const OperatorDef& def_;
246  const vector<GradientWrapper>& g_output_;
247  vector<GradientWrapper> g_input_;
248 };
249 
259 class CAFFE2_API NoGradient : public GradientMakerBase {
260  using GradientMakerBase::GradientMakerBase;
261  vector<OperatorDef> GetGradientDefs() override {
262  return vector<OperatorDef>();
263  }
264 };
265 
273  using GradientMakerBase::GradientMakerBase;
274  GradientOpsMeta Get() override {
275  CAFFE_ENFORCE(
276  false, "One should not call gradient for operator ", def_.type(), ".");
277  }
278 };
279 
288  using GradientMakerBase::GradientMakerBase;
289  GradientOpsMeta Get() override {
290  CAFFE_ENFORCE(
291  false,
292  "Operator ",
293  def_.type(),
294  " should have a gradient but is not implemented yet.");
295  }
296 };
297 
298 C10_DECLARE_REGISTRY(
299  GradientRegistry,
301  const OperatorDef&,
302  const vector<GradientWrapper>&);
303 
#ifdef CAFFE2_NO_GRADIENT_OPS

// When gradient ops are compiled out, registration expands to nothing.
#define REGISTER_GRADIENT(name, ...) /* No gradients. */
#define REGISTER_GRADIENT_STR(str_name, ...) /* No gradients. */

#else

#define REGISTER_GRADIENT(name, ...) \
  C10_REGISTER_CLASS(GradientRegistry, name, __VA_ARGS__)
#define REGISTER_GRADIENT_STR(str_name, ...) \
  C10_REGISTER_TYPED_CLASS(GradientRegistry, str_name, __VA_ARGS__)

#endif

// NO_GRADIENT means that the operator does not need any gradient computation.
#define NO_GRADIENT(name) REGISTER_GRADIENT(name, NoGradient)

// SHOULD_NOT_DO_GRADIENT means that the operator is not designed to have
// gradient operators. If you attempt to call the gradient, a log fatal will
// occur.
#define SHOULD_NOT_DO_GRADIENT(name) \
  REGISTER_GRADIENT(name, ThrowInTheTowelIfGradientIsCalled)

#define GRADIENT_NOT_IMPLEMENTED_YET(name) \
  REGISTER_GRADIENT(name, GradientNotImplementedYet)
329 
334  const OperatorDef& def,
335  const vector<GradientWrapper>& g_output);
336 
337 } // namespace caffe2
338 
339 #endif // CAFFE2_CORE_OPERATOR_GRADIENT_H_
A helper class to indicate that the gradient mechanism is not ready.
static CaffeMap< string, string > MatchGradsToParams(const OperatorDef &op)
Returns map that returns the parameters that the gradients are for.
GradientOpsMeta GetGradientForOp(const OperatorDef &def, const vector< GradientWrapper > &g_output)
Gets the GradientOpsMeta for the given operator def.
Definition: operator.cc:376
A struct that holds the gradient operators and related gradient maps.
A helper class to indicate that the operator should have no gradient.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.
Definition: blob.h:13
static vector< OperatorDef > SingleGradientDef(const Args &...args)
a helper function to allow one to create one single operator def, which is usually the case for many simple operators.
virtual GradientOpsMeta Get()
Returns the gradient ops meta.
GradientOpsMeta Get() override
Returns the gradient ops meta.
A helper class to indicate that the operator does not need gradient computation.
GradientOpsMeta Get() override
Returns the gradient ops meta.