1 #include "caffe2/operators/lengths_reducer_ops.h" 2 #include "caffe2/core/context.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/operators/segment_reduction_op.h" 5 #include "caffe2/utils/math.h" 13 using SparseLengthsSumOp =
14 CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 0, 0>;
// Alias for the CPUSparseLengthsReductionOp instantiation that this file
// registers under the operator name SparseLengthsWeightedSum.
// NOTE(review): the trailing `1, 0` arguments presumably switch the weighted
// variant on and the mean variant off; their meaning is declared in
// lengths_reducer_ops.h, not here -- confirm.
15 using SparseLengthsWeightedSumOp =
16 CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 1, 0>;
// Alias for the CPUSparseLengthsReductionOp instantiation that this file
// registers under the operator name SparseLengthsMean.
// NOTE(review): the trailing `0, 1` arguments presumably switch the weighted
// variant off and the mean variant on; their meaning is declared in
// lengths_reducer_ops.h, not here -- confirm.
17 using SparseLengthsMeanOp =
18 CPUSparseLengthsReductionOp<float, TensorTypes<float, at::Half>, 0, 1>;
// Register the three forward operators on CPU: each operator name is bound to
// the CPUSparseLengthsReductionOp alias of the same name declared above.
19 REGISTER_CPU_OPERATOR(SparseLengthsSum, SparseLengthsSumOp);
20 REGISTER_CPU_OPERATOR(SparseLengthsWeightedSum, SparseLengthsWeightedSumOp);
21 REGISTER_CPU_OPERATOR(SparseLengthsMean, SparseLengthsMeanOp);
// Schema for the SparseLengthsPositionalWeightedSum operator. The visible
// pieces are its documentation text, the description strings of its inputs,
// and a single output (index 0) named "output".
// NOTE(review): this excerpt is missing several lines of the schema (the
// listing numbering jumps 23 -> 27, 29 -> 35, 36 -> 40, ...), so the
// input/output counts, the raw-string delimiters around the doc text and the
// Input(...) calls that own the description strings are not visible here.
// Consult the complete file before editing anything in this chain.
23 OPERATOR_SCHEMA(SparseLengthsPositionalWeightedSum)
27 Variation of SparseLengthsWeightedSum operator, where, for each row, 28 weights are accessed by indices [0..L-1], where L is the length of given row. 29 This is basically a fused operator of LengthsRangeFill + Gather + 35 "uint8 tensor obtained with " 36 "operator FloatToRowwiseQuantized8Bits")
40 "Scalar multipliers for the input slices. Must " 41 "be a vector with the length matching the length of DATA")
45 "Integer vector containing indices of the first " 46 "dimension of DATA for the slices that are being aggregated")
50 "Vector with the same sum of elements as the first dimension of DATA")
51 .Output(0,
"output",
"output");
// Register the CPU implementation of SparseLengthsPositionalWeightedSum by
// string name, using an inline CPUSparseLengthsReductionOp instantiation
// instead of a named alias.
// NOTE(review): compared with the SparseLengthsWeightedSumOp alias (`1, 0`),
// this instantiation passes one extra trailing argument (`1, 0, 1`),
// presumably the positional-weight switch -- confirm in lengths_reducer_ops.h.
53 REGISTER_CPU_OPERATOR_STR(
54 "SparseLengthsPositionalWeightedSum",
55 CPUSparseLengthsReductionOp<
float, TensorTypes<float, at::Half>, 1, 0, 1>);
// Builds the documentation string for an operator definition `Def`: starts
// from Def::doc, substitutes "{op}" with Def::OpDef::name and "{op_doc}" with
// Def::OpDef::doc, then replaces "{extra}" with the empty string.
// NOTE(review): the declaration line and the end of the body (listing lines
// 58 and 64 onwards) are missing from this excerpt; the schema definitions in
// this file invoke it as FormatDoc<Def>() and pass the result to SetDoc, so it
// presumably returns `doc` -- confirm against the complete file.
57 template <
typename Def>
59 string doc = Def::doc;
60 c10::ReplaceAll(doc,
"{op}", Def::OpDef::name);
61 c10::ReplaceAll(doc,
"{op_doc}", Def::OpDef::doc);
// Assuming c10::ReplaceAll returns the number of replacements it made, the
// check below enforces that Def::doc contains no "{extra}" placeholder.
62 auto replaced = c10::ReplaceAll(doc,
"{extra}",
"");
63 CAFFE_ENFORCE_EQ(replaced, 0);
// Definition bundle for SparseLengthsSum. The code below takes ForwardOp,
// BackwardOp, PopulateSchema and GetGradient from this type.
// NOTE(review): the template arguments of AbstractSparseLengthsDef (listing
// lines 68 onwards) are missing from this excerpt.
67 using SparseLengthsSumDef = AbstractSparseLengthsDef<
// Schema for SparseLengthsSum:
//  - the input count is taken from SparseLengthsSumDef::ForwardOp::kNumInputs;
//  - input fillers are configured from the DATA, INDICES and LENGTHS indices
//    of SparseLengthsSumOp;
//  - the doc string is generated by FormatDoc<SparseLengthsSumDef>();
//  - output 0 is named "OUTPUT";
//  - SparseLengthsSumDef::PopulateSchema is applied through FillUsing.
// NOTE(review): listing lines 75 and 83 (right after NumInputs and right after
// FillUsing) are missing from this excerpt, so the chain shown here is not
// complete and has no visible terminator.
73 OPERATOR_SCHEMA(SparseLengthsSum)
74 .NumInputs(SparseLengthsSumDef::ForwardOp::kNumInputs)
76 .ValueKeyLengthInputFillers(
77 SparseLengthsSumOp::DATA,
78 SparseLengthsSumOp::INDICES,
79 SparseLengthsSumOp::LENGTHS)
80 .SetDoc(FormatDoc<SparseLengthsSumDef>())
81 .Output(0,
"OUTPUT",
"Aggregated tensor")
82 .FillUsing(SparseLengthsSumDef::PopulateSchema)
// Gradient wiring for SparseLengthsSum.
// The backward op supplied by SparseLengthsSumDef is registered on CPU under
// the name SparseLengthsSumGradient.
84 REGISTER_CPU_OPERATOR(
85 SparseLengthsSumGradient,
86 SparseLengthsSumDef::BackwardOp);
// Schema of the gradient op: the input count comes from the backward op, and
// input fillers are disallowed.
// NOTE(review): listing line 89 (between NumInputs and DisallowInputFillers)
// is missing from this excerpt.
87 OPERATOR_SCHEMA(SparseLengthsSumGradient)
88 .NumInputs(SparseLengthsSumDef::BackwardOp::kNumInputs)
90 .DisallowInputFillers();
// Associate SparseLengthsSum with the gradient maker from its definition.
91 REGISTER_GRADIENT(SparseLengthsSum, SparseLengthsSumDef::GetGradient)
// Definition bundle for SparseLengthsWeightedSum. The code below takes
// ForwardOp, BackwardOp, PopulateSchema and GetGradient from this type.
// WeightedSumReducerDef is one of its template arguments.
// NOTE(review): the remaining template arguments (listing lines 94-96 and 98)
// are missing from this excerpt.
93 using SparseLengthsWeightedSumDef = AbstractSparseLengthsDef<
97 WeightedSumReducerDef,
// Schema for SparseLengthsWeightedSum:
//  - the input count is taken from
//    SparseLengthsWeightedSumDef::ForwardOp::kNumInputs;
//  - input fillers are configured from the DATA, INDICES, LENGTHS and WEIGHT
//    indices of SparseLengthsWeightedSumOp;
//  - the doc string is generated by FormatDoc<SparseLengthsWeightedSumDef>();
//  - output 0 is named "OUTPUT";
//  - SparseLengthsWeightedSumDef::PopulateSchema is applied through FillUsing;
//  - the chain ends with InheritOnnxSchema().
// NOTE(review): listing line 101 (right after NumInputs) is missing from this
// excerpt.
99 OPERATOR_SCHEMA(SparseLengthsWeightedSum)
100 .NumInputs(SparseLengthsWeightedSumDef::ForwardOp::kNumInputs)
102 .WeightedValueKeyLengthInputFillers(
103 SparseLengthsWeightedSumOp::DATA,
104 SparseLengthsWeightedSumOp::INDICES,
105 SparseLengthsWeightedSumOp::LENGTHS,
106 SparseLengthsWeightedSumOp::WEIGHT)
107 .SetDoc(FormatDoc<SparseLengthsWeightedSumDef>())
108 .Output(0, "OUTPUT", "Aggregated tensor")
109 .FillUsing(SparseLengthsWeightedSumDef::PopulateSchema)
110 .InheritOnnxSchema();
// Gradient wiring for SparseLengthsWeightedSum.
// The backward op supplied by SparseLengthsWeightedSumDef is registered on CPU
// under the name SparseLengthsWeightedSumGradient.
111 REGISTER_CPU_OPERATOR(
112 SparseLengthsWeightedSumGradient,
113 SparseLengthsWeightedSumDef::BackwardOp);
// Schema of the gradient op: the input count comes from the backward op, and
// input fillers are disallowed.
// NOTE(review): listing line 116 (between NumInputs and DisallowInputFillers)
// is missing from this excerpt.
114 OPERATOR_SCHEMA(SparseLengthsWeightedSumGradient)
115 .NumInputs(SparseLengthsWeightedSumDef::BackwardOp::kNumInputs)
117 .DisallowInputFillers();
// NOTE(review): the opening line of the macro call below (listing line 118) is
// missing from this excerpt. By analogy with the SparseLengthsSum and
// SparseLengthsMean sections it is presumably `REGISTER_GRADIENT(` -- confirm.
119 SparseLengthsWeightedSum,
120 SparseLengthsWeightedSumDef::GetGradient)
// Definition bundle for SparseLengthsMean. The code below takes ForwardOp,
// BackwardOp, PopulateSchema and GetGradient from this type.
// NOTE(review): the template arguments of AbstractSparseLengthsDef (listing
// lines 123 onwards) are missing from this excerpt.
122 using SparseLengthsMeanDef = AbstractSparseLengthsDef<
// Schema for SparseLengthsMean:
//  - the input count is taken from SparseLengthsMeanDef::ForwardOp::kNumInputs;
//  - input fillers are configured from the DATA, INDICES and LENGTHS indices
//    of SparseLengthsMeanOp;
//  - the doc string is generated by FormatDoc<SparseLengthsMeanDef>();
//  - output 0 is named "OUTPUT";
//  - SparseLengthsMeanDef::PopulateSchema is applied through FillUsing.
// NOTE(review): listing line 130 (right after NumInputs) is missing from this
// excerpt.
128 OPERATOR_SCHEMA(SparseLengthsMean)
129 .NumInputs(SparseLengthsMeanDef::ForwardOp::kNumInputs)
131 .ValueKeyLengthInputFillers(
132 SparseLengthsMeanOp::DATA,
133 SparseLengthsMeanOp::INDICES,
134 SparseLengthsMeanOp::LENGTHS)
135 .SetDoc(FormatDoc<SparseLengthsMeanDef>())
136 .Output(0, "OUTPUT", "Aggregated tensor")
137 .FillUsing(SparseLengthsMeanDef::PopulateSchema);
// Gradient wiring for SparseLengthsMean.
// The backward op supplied by SparseLengthsMeanDef is registered on CPU under
// the name SparseLengthsMeanGradient.
138 REGISTER_CPU_OPERATOR(
139 SparseLengthsMeanGradient,
140 SparseLengthsMeanDef::BackwardOp);
// Schema of the gradient op: the input count comes from the backward op, and
// input fillers are disallowed.
// NOTE(review): listing line 143 (between NumInputs and DisallowInputFillers)
// is missing from this excerpt.
141 OPERATOR_SCHEMA(SparseLengthsMeanGradient)
142 .NumInputs(SparseLengthsMeanDef::BackwardOp::kNumInputs)
144 .DisallowInputFillers();
// Associate SparseLengthsMean with the gradient maker from its definition.
145 REGISTER_GRADIENT(SparseLengthsMean, SparseLengthsMeanDef::GetGradient)
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...