1 #include "caffe2/operators/pow_op.h" 2 #include "caffe2/utils/eigen_utils.h" 3 #include "caffe2/utils/math.h" 10 #define EIGEN_POW(x, y) (x.pow(y)) 13 template <
int b_is_scalar,
typename T1,
typename T2,
typename R>
15 Run(
size_t n,
const T1* a,
const T2* b, T2 e, R* out,
CPUContext*) {
17 EigenVectorArrayMap<R>(out, n) =
18 EIGEN_POW((ConstEigenVectorArrayMap<T1>(a, n)), (e));
21 EigenVectorArrayMap<R>(out, n) =
22 EIGEN_POW((ConstEigenVectorArrayMap<T1>(a, n)), (b[0]));
24 EigenVectorArrayMap<R>(out, n) = EIGEN_POW(
25 (ConstEigenVectorArrayMap<T1>(a, n)),
26 (ConstEigenVectorArrayMap<T2>(b, n)));
30 template <
typename T1,
typename T2,
typename R>
31 void RunWithBroadcast(
38 EigenArrayMap<R>(out, n, pre) = EIGEN_POW(
39 (ConstEigenArrayMap<T1>(a, n, pre)),
40 (ConstEigenVectorArrayMap<T2>(b, n)).rowwise().replicate(pre));
49 template <
typename T1,
typename T2,
typename R>
50 void RunWithBroadcast2(
58 for (
int i = 0; i < pre; ++i) {
59 EigenArrayMap<R>(out + i * n * post, post, n) = EIGEN_POW(
60 (ConstEigenArrayMap<T1>(a + i * n * post, post, n)),
61 (Eigen::Map<const Eigen::Array<T2, 1, Eigen::Dynamic>>(b, n))
// ---------------------------------------------------------------------------
// CPU registration and schema for the Pow operator.
//
// NOTE(review): this span is a mangled extraction. The arguments of
// REGISTER_CPU_OPERATOR(...) and the OPERATOR_SCHEMA(...) head (operator
// name, NumInputs/NumOutputs) are missing from this view, and the SetDoc()
// markdown string is fused onto a single line with stray line numbers —
// that string is runtime data and is left byte-for-byte untouched below.
//
// What the visible fragments establish:
//   - in-place execution is allowed from either input into output 0;
//   - output 0 inherits the type and shape of input 0;
//   - inputs are X (data) and an optional `exponent` tensor, output is Y;
//   - args are `exponent` (scalar alternative to input 1), `axis`
//     (int, default -1) and `broadcast` (bool, default False).
// Restore the missing schema head from the original file before building.
75 REGISTER_CPU_OPERATOR(
86 .AllowInplace({{0, 0}, {1, 0}})
87 .IdenticalTypeAndShapeOfInput(0)
89 The *Pow* op takes an input data tensor $X$ and an exponent parameter *exponent*, which can be a scalar or another tensor. As output, it produces a single output data tensor $Y$, where the function $f(x) = x^{exponent}$ has been applied to $X$ elementwise. 93 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/pow_op.h 94 - https://github.com/pytorch/pytorch/blob/master/caffe2/operators/pow_op.cc 99 <summary> <b>Example</b> </summary> 105 workspace.ResetWorkspace() 107 op = core.CreateOperator( 114 workspace.FeedBlob("X", np.array([1,2,3,4,5,6]).astype(np.float32)) 115 print("X: ", workspace.FetchBlob("X")) 117 workspace.FeedBlob("exponent", np.array([2]).astype(np.float32)) 118 print("exponent: ", workspace.FetchBlob("exponent")) 120 workspace.RunOperatorOnce(op) 121 print("Y: ", workspace.FetchBlob("Y")) 129 X: [1. 2. 3. 4. 5. 6.] 131 Y: [ 1. 4. 9. 16. 25. 36.] 139 .Input(0, "X",
"Input data blob to be operated on.")
140 .Input(1,
"exponent",
"Exponent blob containing the exponent(s) for calculation. Do not use if setting exponent via argument.")
141 .Output(0,
"Y",
"Output data blob with the same shape as the input.")
142 .Arg(
"exponent",
"The exponent of the power function. Do not use if setting exponent via input.")
143 .Arg(
"axis",
"*(type: int; default: -1)*")
144 .Arg(
"broadcast",
"*(type: bool; default: False)*");
// ---------------------------------------------------------------------------
// Gradient maker for Pow — scalar-`exponent` argument path.
//
// NOTE(review): mangled extraction. The class head (presumably
// `class GetPowGradient : public GradientMakerBase`), the Argument
// declarations for scale_arg/pow_arg, the operator-type strings inside
// CreateOperatorDef, and the if/else structure around the in-place branch
// are all missing from this view. The verbatim fragments are kept below.
//
// Visible logic: when the "exponent" argument is present, dX is built as a
// three-op chain — Pow(X, exponent-1), Mul with dY, then a scale by
// `exponent` — i.e. d(x^e)/dx = e * x^(e-1). An in-place variant instead
// recovers x^(e-1) from the output as Y^((e-1)/e) (hence the precision
// warning), guarded by CAFFE_ENFORCE against exponent == 0.
147 using GradientMakerBase::GradientMakerBase;
148 vector<OperatorDef> GetGradientDefs()
override {
150 if (arg_helper.HasArgument(
"exponent")) {
// Scalar exponent read from the op's arguments (default 0.0).
154 float exponent = arg_helper.GetSingleArgument<
float>(
"exponent", 0.0);
156 scale_arg.set_name(
"scale");
157 scale_arg.set_f(exponent);
159 pow_arg.set_name(
"exponent");
161 pow_arg.set_f(exponent - 1);
// In-place branch: X is gone, so x^(e-1) is recovered from Y = x^e as
// Y^((e-1)/e); division by e requires e != 0 (kEps guard below).
163 LOG(WARNING) <<
"In-place Pow gradient, possible loss of precision";
164 constexpr
float kEps = 1e-12f;
165 CAFFE_ENFORCE(std::fabs(exponent) > kEps);
166 pow_arg.set_f((exponent - 1) / exponent);
// Three chained ops into GI(0); the operator-type string arguments of
// each CreateOperatorDef were lost in extraction.
168 return vector<OperatorDef>{CreateOperatorDef(
171 std::vector<string>{I(0)},
172 std::vector<string>{GI(0)},
173 std::vector<Argument>{pow_arg}),
177 std::vector<string>{GI(0), GO(0)},
178 std::vector<string>{GI(0)}),
182 std::vector<string>{GI(0)},
183 std::vector<string>{GI(0)},
184 std::vector<Argument>{scale_arg})};
// ---------------------------------------------------------------------------
// Gradient maker for Pow — tensor-exponent (input 1) path, plus gradient
// registration.
//
// NOTE(review): mangled extraction. The enclosing CAFFE_ENFORCE head, the
// else-branch structure, every operator-type string inside
// CreateOperatorDef, and the bodies around bflag are missing. Verbatim
// fragments kept below. The visible shape of the logic:
//   - in-place with a tensor exponent is rejected up front (the enforce
//     on input(0)/input(1) != output(0));
//   - broadcast/axis/axis_str/order Arguments are forwarded from the
//     forward op, with defaults (0, -1, "", "NCHW") when absent;
//   - dX is assembled through a GI(1) scratch sequence using I(0), I(1),
//     GO(0), with broadcast args applied on the broadcasting variants;
//   - d(exponent) is accumulated in a "<GI(1)>_autogen_pre_red" buffer
//     from X, Y and dY, then reduced to GI(1) — presumably via a
//     SumReduceLike-style op given the axis/axis_str/order args; confirm
//     the missing op-type strings against the original file.
206 Def().input(0) != Def().output(0) &&
207 Def().input(1) != Def().output(0),
208 "Gradient computation cannot be carried out if Pow uses in-place " 210 ProtoDebugString(Def()));
211 vector<OperatorDef> grad_ops;
// one_arg carries a fill "value" (presumably 1.0; its set_f line is not
// visible here).
213 one_arg.set_name(
"value");
// Forward the broadcasting configuration of the forward op, defaulting
// each argument when the forward op did not set it.
215 Argument broadcast, axis, axis_str, order;
216 bool bflag = ArgumentHelper::HasArgument(Def(),
"broadcast");
219 if (ArgumentHelper::HasArgument(Def(),
"broadcast")) {
220 broadcast = GetArgument(Def(),
"broadcast");
222 broadcast = MakeArgument<int>(
"broadcast", 0);
224 if (ArgumentHelper::HasArgument(Def(),
"axis")) {
225 axis = GetArgument(Def(),
"axis");
227 axis = MakeArgument<int>(
"axis", -1);
229 if (ArgumentHelper::HasArgument(Def(),
"axis_str")) {
230 axis_str = GetArgument(Def(),
"axis_str");
232 axis_str = MakeArgument<string>(
"axis_str",
"");
234 if (ArgumentHelper::HasArgument(Def(),
"order")) {
235 order = GetArgument(Def(),
"order");
237 order = MakeArgument<string>(
"order",
"NCHW");
// dX pipeline: builds exponent-1 in GI(1) (fill + subtract), raises X to
// it (broadcast and non-broadcast variants visible), multiplies by dY and
// by the exponent tensor. Op-type strings missing throughout.
244 grad_ops.push_back(CreateOperatorDef(
247 std::vector<string>{I(1)},
248 std::vector<string>{GI(1)},
249 std::vector<Argument>{one_arg}));
250 grad_ops.push_back(CreateOperatorDef(
253 std::vector<string>{I(1), GI(1)},
254 std::vector<string>{GI(1)}));
256 grad_ops.push_back(CreateOperatorDef(
259 std::vector<string>{I(0), GI(1)},
260 std::vector<string>{GI(0)},
261 vector<Argument>{broadcast, axis, axis_str, order}));
263 grad_ops.push_back(CreateOperatorDef(
266 std::vector<string>{I(0), GI(1)},
267 std::vector<string>{GI(0)}));
270 grad_ops.push_back(CreateOperatorDef(
273 std::vector<string>{GI(0), GO(0)},
274 std::vector<string>{GI(0)}));
276 grad_ops.push_back(CreateOperatorDef(
279 std::vector<string>{GI(0), I(1)},
280 std::vector<string>{GI(0)},
281 vector<Argument>{broadcast, axis, axis_str, order}));
283 grad_ops.push_back(CreateOperatorDef(
286 std::vector<string>{GI(0), I(1)},
287 std::vector<string>{GI(0)}));
// d(exponent) pipeline: pre-reduction buffer from X (log?), Y and dY,
// then a reduce-like collapse to GI(1) when broadcasting. Op-type strings
// missing; hedged above.
321 grad_ops.push_back(CreateOperatorDef(
324 std::vector<string>{I(0)},
325 std::vector<string>{GI(1) +
"_autogen_pre_red"}));
326 grad_ops.push_back(CreateOperatorDef(
329 std::vector<string>{GI(1) +
"_autogen_pre_red", O(0)},
330 std::vector<string>{GI(1) +
"_autogen_pre_red"}));
332 grad_ops.push_back(CreateOperatorDef(
335 std::vector<string>{GI(1) +
"_autogen_pre_red", GO(0)},
336 std::vector<string>{GI(1) +
"_autogen_pre_red"}));
337 grad_ops.push_back(CreateOperatorDef(
340 vector<string>{GI(1) +
"_autogen_pre_red", I(1)},
341 vector<string>{GI(1)},
342 vector<Argument>{axis, axis_str, order}));
344 grad_ops.push_back(CreateOperatorDef(
347 std::vector<string>{GI(1) +
"_autogen_pre_red", GO(0)},
348 std::vector<string>{GI(1)}));
// CopyArguments override (body not visible) and gradient registration.
356 bool CopyArguments()
const override {
361 REGISTER_GRADIENT(Pow, GetPowGradient);
// NOTE(review): the three lines below are extraction artifacts (hover-card
// text from a code-browsing tool), not part of pow_op.cc. Commented out so
// they cannot be mistaken for source; remove once the file is restored.
// The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
// A helper class to index into arguments.
// A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...