// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// segment_reduction_op.cc
#include "caffe2/operators/segment_reduction_op.h"

#include "caffe2/core/types.h"
2 
3 namespace caffe2 {
4 
5 OpSchema::Cost CostInferenceForSparseLengths(
6  const OperatorDef& def,
7  const vector<TensorShape>& inputs,
8  bool use_weight) {
9  int min_num_of_inputs = 3 + use_weight;
10  CAFFE_ENFORCE_GE(
11  inputs.size(),
12  min_num_of_inputs,
13  def.type() + " requires at least " + c10::to_string(min_num_of_inputs));
14 
15  const TensorShape data = inputs[0];
16  const TensorShape indices = inputs[1 + use_weight];
17  const TensorShape lengths = inputs[2 + use_weight];
18 
19  OpSchema::Cost c;
20  CAFFE_ENFORCE_GT(data.dims_size(), 0, "data requires at least 1 dimension");
21  uint64_t N = data.dims(0);
22  if (N == 0) {
23  return c;
24  }
25  uint64_t D = nElemFromDim(data, 1);
26  CAFFE_ENFORCE_GT(
27  lengths.dims_size(), 0, "lengths requires at least 1 dimension");
28  uint64_t M = lengths.dims(0);
29  uint64_t indices_size = nElemFromDim(indices);
30 
31  c.flops = indices_size * D;
32  c.bytes_read = indices_size *
33  (D * sizeof(data.data_type()) + sizeof(indices.data_type())) +
34  M * sizeof(lengths.data_type());
35  c.params_bytes = N * D * sizeof(data.data_type());
36  if (use_weight) {
37  const TensorShape weights = inputs[1];
38  c.flops += indices_size * D;
39  c.bytes_read += indices_size * sizeof(weights.data_type());
40  }
41 
42  return c;
43 }
44 
// registering 5 input gradient with main output
// gradient of SparseLengthsWeightedSum
// SparseFused=true and GradientNeedIndices=true: the template flags indicate
// the forward op's INDICES blob is made available to this gradient op.
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient)
    .NumInputs(5)
    .NumOutputs(2);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientWeightedSumWithMainInputGradient,
    AbstractLengthsWithMainInputGradientOp<
        float,
        float,
        int,
        CPUContext,
        WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*SparseFused*/,
        true /*GradientNeedIndices*/>);

// registering 4 input version
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientWeightedSumGradient)
    .NumInputs(4)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientWeightedSumGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        WeightedSumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);

// registering 3 input version
// gradient of SparseLengthsSum
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientSumGradient)
    .NumInputs(3)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientSumGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        SumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
// gradient of LengthsSum
OPERATOR_SCHEMA(LengthsIndicesInGradientSumGradient).NumInputs(3).NumOutputs(1);
REGISTER_CPU_OPERATOR(
    LengthsIndicesInGradientSumGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        SumReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);

// registering 3 input version
// gradient of SparseLengthsMean
OPERATOR_SCHEMA(SparseLengthsIndicesInGradientMeanGradient)
    .NumInputs(3)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    SparseLengthsIndicesInGradientMeanGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        MeanReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
// gradient of LengthsMean
OPERATOR_SCHEMA(LengthsIndicesInGradientMeanGradient)
    .NumInputs(3)
    .NumOutputs(1);
REGISTER_CPU_OPERATOR(
    LengthsIndicesInGradientMeanGradient,
    AbstractLengthsGradientOp<
        float,
        int,
        CPUContext,
        MeanReducerDef::template ReducerGradient<float, CPUContext>,
        true /*GradientNeedIndices*/>);
123 
124 namespace {
125 
// "{extra}" documentation section for the LengthsMax schema; substituted
// into the op doc by FormatDoc() below. Markdown rendered in the operator
// catalog — keep the text verbatim.
static const char* kLengthsMaxExtra = R"DOC(
The *LengthsMax* op takes two inputs *DATA* and *LENGTHS*, and produces a single output *OUTPUT*. The op finds the maximum value in each of the segments of *DATA*, where segments are defined by their lengths.
For example, if $DATA = [2,4,3,1,2,10]$ and $LENGTHS = [2,3,1]$ then $OUTPUT = [max([2,4]), max([3,1,2]), max([10])] = [4,3,10]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsMax",
    ["DATA", "LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
 [ 2. 4. 3. 1. 2. 10.]
LENGTHS:
 [2 3 1]
OUTPUT:
 [ 4. 3. 10.]

```

</details>

)DOC";
176 
// "{extra}" documentation section for the LengthsMean schema; substituted
// into the op doc by FormatDoc() below. Keep the text verbatim.
static const char* kLengthsMeanExtra = R"DOC(
The *LengthsMean* op takes two inputs *DATA* and *LENGTHS*, and produces a single output *OUTPUT*. The op finds the mean value in each of the segments of *DATA*, where segments are defined by their lengths.
For example, if $DATA = [2,4,3,1,2,10]$ and $LENGTHS = [2,3,1]$ then $OUTPUT = [mean([2,4]), mean([3,1,2]), mean([10])] = [3,2,10]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsMean",
    ["DATA", "LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
 [ 2. 4. 3. 1. 2. 10.]
LENGTHS:
 [2 3 1]
OUTPUT:
 [ 3. 2. 10.]

```

</details>

)DOC";
227 
// "{extra}" documentation section for the LengthsSum schema; substituted
// into the op doc by FormatDoc() below. Keep the text verbatim.
static const char* kLengthsSumExtra = R"DOC(
The *LengthsSum* op takes two inputs *DATA* and *LENGTHS*, and produces a single output *OUTPUT*. The op finds the sum in each of the segments of *DATA*, where segments are defined by their lengths.
For example, if $DATA = [2,4,3,1,2,10]$ and $LENGTHS = [2,3,1]$ then $OUTPUT = [sum([2,4]), sum([3,1,2]), sum([10])] = [6,6,10]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsSum",
    ["DATA", "LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
 [ 2. 4. 3. 1. 2. 10.]
LENGTHS:
 [2 3 1]
OUTPUT:
 [ 6. 6. 10.]

```

</details>

)DOC";
278 
// "{extra}" documentation section for the LengthsWeightedSum schema;
// substituted into the op doc by FormatDoc() below. Keep the text verbatim.
static const char* kLengthsWeightedSumExtra = R"DOC(
The *LengthsWeightedSum* op takes three inputs *DATA*, *LENGTHS*, and *SCALARS*, and produces a single output *OUTPUT*. The op finds the weighted sum in each of the segments of *DATA*, where segments are defined by their lengths. Before calculating the sums, the input *DATA* is weighted by the contents of *SCALARS*.
For example, if $DATA = [2,4,3,1,2,10]$, $SCALARS = [8, 2, 1, 4, 1, 0.6]$, and $LENGTHS = [2,3,1]$, then $OUTPUT = [sum([8*2,2*4]), sum([1*3,4*1,1*2]), sum([0.6*10])] = [24,9,6]$.

Github Link:
- https://github.com/caffe2/caffe2/blob/master/caffe2/operators/segment_reduction_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "LengthsWeightedSum",
    ["DATA", "SCALARS","LENGTHS"],
    ["OUTPUT"],
)

workspace.FeedBlob("DATA", np.array([2,4,3,1,2,10]).astype(np.float32))
print("DATA:\n", workspace.FetchBlob("DATA"))

workspace.FeedBlob("SCALARS", np.array([8, 2, 1, 4, 1, 0.6]).astype(np.float32))
print("SCALARS:\n", workspace.FetchBlob("SCALARS"))

workspace.FeedBlob("LENGTHS", np.array([2,3,1]).astype(np.int32))
print("LENGTHS:\n", workspace.FetchBlob("LENGTHS"))

workspace.RunOperatorOnce(op)
print("OUTPUT: \n", workspace.FetchBlob("OUTPUT"))

```

**Result**

```

DATA:
 [ 2. 4. 3. 1. 2. 10.]
SCALARS:
 [8. 2. 1. 4. 1. 0.6]
LENGTHS:
 [2 3 1]
OUTPUT:
 [24. 9. 6.]

```

</details>

)DOC";
334 
335 template <typename Def>
336 string FormatDoc() {
337  string doc = Def::doc;
338  c10::ReplaceAll(doc, "{op}", Def::OpDef::name);
339  c10::ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
340  if (strcmp(Def::OpDef::name, "Max") == 0) {
341  c10::ReplaceAll(doc, "{extra}", kLengthsMaxExtra);
342  } else if (strcmp(Def::OpDef::name, "Mean") == 0) {
343  c10::ReplaceAll(doc, "{extra}", kLengthsMeanExtra);
344  } else if (strcmp(Def::OpDef::name, "Sum") == 0) {
345  c10::ReplaceAll(doc, "{extra}", kLengthsSumExtra);
346  } else if (strcmp(Def::OpDef::name, "WeightedSum") == 0) {
347  c10::ReplaceAll(doc, "{extra}", kLengthsWeightedSumExtra);
348  } else {
349  c10::ReplaceAll(doc, "{extra}", " ");
350  }
351  return doc;
352 }
353 
// Helper function to enforce naming conventions at compile time.
// Returns true iff `lhs` is exactly the concatenation rhs1 + rhs2 + rhs3.
// Used from static_assert, so it must stay constexpr.
constexpr bool equal(
    char const* lhs,
    char const* rhs1,
    char const* rhs2,
    char const* rhs3 = "") {
  // Consume each expected piece in order, matching it against lhs.
  for (; *rhs1 != 0; ++rhs1, ++lhs) {
    if (*lhs != *rhs1) {
      return false;
    }
  }
  for (; *rhs2 != 0; ++rhs2, ++lhs) {
    if (*lhs != *rhs2) {
      return false;
    }
  }
  for (; *rhs3 != 0; ++rhs3, ++lhs) {
    if (*lhs != *rhs3) {
      return false;
    }
  }
  // All pieces matched; lhs must be fully consumed too.
  return *lhs == 0;
}
367 
// Helper macro when the main op is defined elsewhere, and we only need to
// define the schema, and the gradient op.
// TODO: enable input fillers
//
// The two static_asserts enforce the naming convention (checked by the
// constexpr `equal` helper above):
//   segment_name  == Def::basename + Def::OpDef::name
//   gradient_name == Def::basename + Def::OpDef::name + "Gradient"
// Then it: publishes the forward schema (doc via FormatDoc), registers the
// CPU gradient op, publishes the gradient schema, and hooks up GetGradient.
#define REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(                            \
    segment_name, gradient_name, ...)                                         \
  static_assert(                                                              \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name),  \
      #segment_name);                                                         \
  static_assert(                                                              \
      equal(                                                                  \
          #gradient_name,                                                     \
          __VA_ARGS__::basename,                                              \
          __VA_ARGS__::OpDef::name,                                           \
          "Gradient"),                                                        \
      #gradient_name);                                                        \
  OPERATOR_SCHEMA(segment_name)                                               \
      .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs)                          \
      .NumOutputs(1)                                                          \
      .DisallowInputFillers()                                                 \
      .SetDoc(FormatDoc<__VA_ARGS__>())                                       \
      .Output(0, "OUTPUT", "Aggregated tensor")                               \
      .FillUsing(__VA_ARGS__::PopulateSchema);                                \
  REGISTER_CPU_OPERATOR_STR(string(#gradient_name), __VA_ARGS__::BackwardOp); \
  OPERATOR_SCHEMA(gradient_name)                                              \
      .NumInputs(__VA_ARGS__::BackwardOp::kNumInputs)                         \
      .NumOutputs(1)                                                          \
      .DisallowInputFillers();                                                \
  REGISTER_GRADIENT_STR(string(#segment_name), __VA_ARGS__::GetGradient)
396 
// Registers both the forward CPU op and everything from
// REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY (schemas + gradient).
#define REGISTER_SEGMENT_DEF(segment_name, gradient_name, ...)               \
  static_assert(                                                             \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name), \
      #segment_name);                                                        \
  REGISTER_CPU_OPERATOR_STR(string(#segment_name), __VA_ARGS__::ForwardOp);  \
  REGISTER_SEGMENT_DEF_SCHEMA_GRADIENT_ONLY(                                 \
      segment_name, gradient_name, __VA_ARGS__)
404 
// ---- SortedSegmentRange* ops (range reducers over sorted segment ids) ----
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeSum,
    SortedSegmentRangeSumGradient,
    AbstractSortedSegmentRangeDef<float, int, CPUContext, SumRangeReducerDef>);
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeLogSumExp,
    SortedSegmentRangeLogSumExpGradient,
    AbstractSortedSegmentRangeDef<
        float,
        int,
        CPUContext,
        LogSumExpRangeReducerDef>);
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeLogMeanExp,
    SortedSegmentRangeLogMeanExpGradient,
    AbstractSortedSegmentRangeDef<
        float,
        int,
        CPUContext,
        LogMeanExpRangeReducerDef>);
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeMean,
    SortedSegmentRangeMeanGradient,
    AbstractSortedSegmentRangeDef<float, int, CPUContext, MeanRangeReducerDef>);
REGISTER_SEGMENT_DEF(
    SortedSegmentRangeMax,
    SortedSegmentRangeMaxGradient,
    AbstractSortedSegmentRangeDef<float, int, CPUContext, MaxRangeReducerDef>);

// ---- Sum reducer: sorted / unsorted, dense / sparse variants ----
REGISTER_SEGMENT_DEF(
    SortedSegmentSum,
    SortedSegmentSumGradient,
    AbstractSortedSegmentDef<float, int, CPUContext, SumReducerDef>);
REGISTER_SEGMENT_DEF(
    SparseSortedSegmentSum,
    SparseSortedSegmentSumGradient,
    AbstractSparseSortedSegmentDef<float, int, CPUContext, SumReducerDef>);
REGISTER_SEGMENT_DEF(
    UnsortedSegmentSum,
    UnsortedSegmentSumGradient,
    AbstractUnsortedSegmentDef<float, int, CPUContext, SumReducerDef>);
REGISTER_SEGMENT_DEF(
    SparseUnsortedSegmentSum,
    SparseUnsortedSegmentSumGradient,
    AbstractSparseUnsortedSegmentDef<float, int, CPUContext, SumReducerDef>);

REGISTER_SEGMENT_DEF(
    LengthsSum,
    LengthsSumGradient,
    AbstractLengthsDef<float, int, CPUContext, SumReducerDef, true>);

// ---- Mean reducer variants ----
REGISTER_SEGMENT_DEF(
    SortedSegmentMean,
    SortedSegmentMeanGradient,
    AbstractSortedSegmentDef<float, int, CPUContext, MeanReducerDef>);
REGISTER_SEGMENT_DEF(
    SparseSortedSegmentMean,
    SparseSortedSegmentMeanGradient,
    AbstractSparseSortedSegmentDef<float, int, CPUContext, MeanReducerDef>);
REGISTER_SEGMENT_DEF(
    UnsortedSegmentMean,
    UnsortedSegmentMeanGradient,
    AbstractUnsortedSegmentDef<float, int, CPUContext, MeanReducerDef>);
REGISTER_SEGMENT_DEF(
    SparseUnsortedSegmentMean,
    SparseUnsortedSegmentMeanGradient,
    AbstractSparseUnsortedSegmentDef<float, int, CPUContext, MeanReducerDef>);

REGISTER_SEGMENT_DEF(
    LengthsMean,
    LengthsMeanGradient,
    AbstractLengthsDef<float, int, CPUContext, MeanReducerDef, true>);

// ---- WeightedSum reducer variants ----
REGISTER_SEGMENT_DEF(
    ReduceFrontWeightedSum,
    ReduceFrontWeightedSumGradient,
    AbstractReduceFrontDef<float, CPUContext, WeightedSumReducerDef>);
REGISTER_SEGMENT_DEF(
    SortedSegmentWeightedSum,
    SortedSegmentWeightedSumGradient,
    AbstractSortedSegmentDef<float, int, CPUContext, WeightedSumReducerDef>);
REGISTER_SEGMENT_DEF(
    SparseSortedSegmentWeightedSum,
    SparseSortedSegmentWeightedSumGradient,
    AbstractSparseSortedSegmentDef<
        float,
        int,
        CPUContext,
        WeightedSumReducerDef>);
REGISTER_SEGMENT_DEF(
    UnsortedSegmentWeightedSum,
    UnsortedSegmentWeightedSumGradient,
    AbstractUnsortedSegmentDef<float, int, CPUContext, WeightedSumReducerDef>);
REGISTER_SEGMENT_DEF(
    SparseUnsortedSegmentWeightedSum,
    SparseUnsortedSegmentWeightedSumGradient,
    AbstractSparseUnsortedSegmentDef<
        float,
        int,
        CPUContext,
        WeightedSumReducerDef>);
REGISTER_SEGMENT_DEF(
    LengthsWeightedSum,
    LengthsWeightedSumGradient,
    AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef, false>);
510 
// Auxiliary output gradients are currently implemented only for Lengths version
//
// Registers a gradient op that also consumes the forward op's main input
// (DATA). Naming convention enforced: gradient_name == basename +
// OpDef::name + "WithMainInputGradient". NumOutputs is open-ended because
// the op may emit auxiliary input gradients in addition to the data
// gradient.
#define REGISTER_GRADIENT_WITH_MAIN_INPUT(gradient_name, ...)        \
  static_assert(                                                     \
      equal(                                                         \
          #gradient_name,                                            \
          __VA_ARGS__::basename,                                     \
          __VA_ARGS__::OpDef::name,                                  \
          "WithMainInputGradient"),                                  \
      #gradient_name);                                               \
  REGISTER_CPU_OPERATOR_STR(                                         \
      string(#gradient_name), __VA_ARGS__::WithMainInputBackwardOp); \
  OPERATOR_SCHEMA(gradient_name)                                     \
      .NumInputs(__VA_ARGS__::WithMainInputBackwardOp::kNumInputs)   \
      .NumOutputs(1, INT_MAX)

REGISTER_GRADIENT_WITH_MAIN_INPUT(
    LengthsWeightedSumWithMainInputGradient,
    AbstractLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
REGISTER_GRADIENT_WITH_MAIN_INPUT(
    SparseLengthsWeightedSumWithMainInputGradient,
    AbstractSparseLengthsDef<float, int, CPUContext, WeightedSumReducerDef>);
532 } // namespace
533 
// Registers a gradient op that consumes both the forward op's main input
// and its forward output (needed e.g. by max-style reducers, where the
// gradient flows only to the winning elements). Naming convention:
// gradient_name == basename + OpDef::name +
// "WithMainInputAndForwardOutputGradient".
#define REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT(             \
    gradient_name, ...)                                                   \
  static_assert(                                                          \
      equal(                                                              \
          #gradient_name,                                                 \
          __VA_ARGS__::basename,                                          \
          __VA_ARGS__::OpDef::name,                                       \
          "WithMainInputAndForwardOutputGradient"),                       \
      #gradient_name);                                                    \
  REGISTER_CPU_OPERATOR_STR(                                              \
      string(#gradient_name),                                             \
      __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp);              \
  OPERATOR_SCHEMA(gradient_name)                                          \
      .NumInputs(                                                         \
          __VA_ARGS__::WithMainInputAndForwardOutputBackwardOp::kNumInputs) \
      .NumOutputs(1, INT_MAX)
550 
// Publishes the forward schema and wires up the
// with-main-input-and-forward-output gradient for a segment op whose
// forward CPU op is registered separately.
#define REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(         \
    segment_name, gradient_name, ...)                                        \
  static_assert(                                                             \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name), \
      #segment_name);                                                        \
  OPERATOR_SCHEMA(segment_name)                                              \
      .NumInputs(__VA_ARGS__::ForwardOp::kNumInputs)                         \
      .NumOutputs(1)                                                         \
      .SetDoc(FormatDoc<__VA_ARGS__>())                                      \
      .Output(0, "OUTPUT", "Aggregated tensor")                              \
      .FillUsing(__VA_ARGS__::PopulateSchema);                               \
  REGISTER_GRADIENT_WITH_MAIN_INPUT_AND_FORWARD_OUTPUT(                      \
      gradient_name, __VA_ARGS__);                                           \
  REGISTER_GRADIENT_STR(string(#segment_name), __VA_ARGS__::GetGradient)
565 
// This implements and registers a length op with a gradient which requires
// the main input as well as the output of the forward output.
// (Registers the forward CPU op, then delegates to
// REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT for the
// schema and gradient.)
#define REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(         \
    segment_name, gradient_name, ...)                                        \
  static_assert(                                                             \
      equal(#segment_name, __VA_ARGS__::basename, __VA_ARGS__::OpDef::name), \
      #segment_name);                                                        \
  REGISTER_CPU_OPERATOR_STR(string(#segment_name), __VA_ARGS__::ForwardOp);  \
  REGISTER_SEGMENT_DEF_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(               \
      segment_name, gradient_name, __VA_ARGS__)
576 
// LengthsMax: the max reducer's gradient needs both DATA and the forward
// OUTPUT, so it uses the main-input-and-forward-output registration path.
REGISTER_LENGTHS_OPS_MAIN_INPUT_AND_FORWARD_OUTPUT_GRADIENT(
    LengthsMax,
    LengthsMaxWithMainInputAndForwardOutputGradient,
    AbstractLengthsDef<float, int, CPUContext, MaxReducerDef>);
581 } // namespace caffe2
// NOTE(review): the following lines are doxygen cross-reference residue from
// documentation extraction, not part of the original translation unit; kept
// as comments so the file remains compilable.
// Definition: any.cpp:108
// A global dictionary that holds information about what Caffe2 modules have
// been loaded in the current runtime.
// Definition: blob.h:13
// Definition: static.cpp:70