Caffe2 - C++ API
A deep learning, cross-platform ML framework
lengths_reducer_fused_8bit_rowwise_ops.cc
1 #include "caffe2/operators/lengths_reducer_fused_8bit_rowwise_ops.h"
2 #include "c10/util/Registry.h"
3 
4 namespace caffe2 {
5 
6 REGISTER_CPU_OPERATOR(
7  SparseLengthsSumFused8BitRowwise,
8  SparseLengthsFused8BitRowwiseOp<CPUContext>);
9 OPERATOR_SCHEMA(SparseLengthsSumFused8BitRowwise)
10  .NumInputs(3)
11  .NumOutputs(1)
12  .ValueKeyLengthInputFillers(
13  SparseLengthsFused8BitRowwiseOp<CPUContext>::DATA,
14  SparseLengthsFused8BitRowwiseOp<CPUContext>::INDICES,
15  SparseLengthsFused8BitRowwiseOp<CPUContext>::LENGTHS)
16  .SetDoc(R"DOC(
17 Performs the same operation as SparseLengthsSum, but operating on
18 8-bit rowwise quantized matrices with fused storage (where each row
19 stores quantized values, and then 4-byte scale and 4-byte bias).
20 )DOC")
21  .Input(
22  0,
23  "DATA",
24  "uint8 tensor obtained with "
25  "operator FloatToFused8BitRowwiseQuantized")
26  .Input(
27  1,
28  "INDICES",
29  "Integer vector containing indices of the first "
30  "dimension of DATA for the slices that are being aggregated")
31  .Input(
32  2,
33  "LENGTHS",
34  "Vector with the same sum of elements as the first dimension of DATA")
35  .Output(0, "output", "output")
36  .InheritOnnxSchema();
37 NO_GRADIENT(SparseLengthsSumFused8BitRowwise);
38 
39 REGISTER_CPU_OPERATOR(
40  SparseLengthsWeightedSumFused8BitRowwise,
41  SparseLengthsFused8BitRowwiseOp<CPUContext, /*with_weights=*/true>);
42 OPERATOR_SCHEMA(SparseLengthsWeightedSumFused8BitRowwise)
43  .NumInputs(4)
44  .NumOutputs(1)
45  .WeightedValueKeyLengthInputFillers(
46  SparseLengthsFused8BitRowwiseOp<CPUContext, true>::DATA,
47  SparseLengthsFused8BitRowwiseOp<CPUContext, true>::INDICES,
48  SparseLengthsFused8BitRowwiseOp<CPUContext, true>::LENGTHS,
49  SparseLengthsFused8BitRowwiseOp<CPUContext, true>::WEIGHTS)
50  .SetDoc(R"DOC(
51 Performs the same operation as SparseLengthsWeightedSum,
52 but operating on 8-bit rowwise quantized matrices with fused storage
53 (where each row stores quantized values, and then 4-byte scale and 4-byte bias).
54 )DOC")
55  .Input(
56  0,
57  "DATA",
58  "uint8 tensor obtained with "
59  "operator FloatToFused8BitRowwiseQuantized")
60  .Input(
61  1,
62  "INDICES",
63  "Integer vector containing indices of the first "
64  "dimension of DATA for the slices that are being aggregated")
65  .Input(
66  2,
67  "LENGTHS",
68  "Vector with the same sum of elements as the first dimension of DATA")
69  .Input(
70  3,
71  "WEIGHTS",
72  "Vector of weights to scale rows of DATA with before reduction")
73  .Output(0, "output", "output");
74 
75 NO_GRADIENT(SparseLengthsWeightedSumFused8BitRowwise);
76 
77 REGISTER_CPU_OPERATOR(
78  SparseLengthsMeanFused8BitRowwise,
79  SparseLengthsFused8BitRowwiseOp<
80  CPUContext,
81  /*with_weights=*/false,
82  /*is_mean=*/true>);
83 OPERATOR_SCHEMA(SparseLengthsMeanFused8BitRowwise)
84  .NumInputs(3)
85  .NumOutputs(1)
86  .ValueKeyLengthInputFillers(
87  SparseLengthsFused8BitRowwiseOp<CPUContext, false, true>::DATA,
88  SparseLengthsFused8BitRowwiseOp<CPUContext, false, true>::INDICES,
89  SparseLengthsFused8BitRowwiseOp<CPUContext, false, true>::LENGTHS)
90  .SetDoc(R"DOC(
91 Performs the same operation as SparseLengthsMean, but
92 operating on 8-bit rowwise quantized matrices with fused storage
93 (where each row stores quantized values, and then 4-byte scale and 4-byte bias).
94 )DOC")
95  .Input(
96  0,
97  "DATA",
98  "uint8 tensor obtained with "
99  "operator FloatToFused8BitRowwiseQuantized")
100  .Input(
101  1,
102  "INDICES",
103  "Integer vector containing indices of the first "
104  "dimension of DATA for the slices that are being aggregated")
105  .Input(
106  2,
107  "LENGTHS",
108  "Vector with the same sum of elements as the first dimension of DATA")
109  .Output(0, "output", "output");
110 NO_GRADIENT(SparseLengthsMeanFused8BitRowwise);
111 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13