Caffe2 - C++ API
A deep learning, cross-platform ML framework
pow_op.cc
#include "caffe2/operators/pow_op.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
// definitions of NumericTypes and SameTypeAsInput are in the header file below
//#include "caffe2/operators/elementwise_op.h"
#include <Eigen/Core>

namespace caffe2 {

#define EIGEN_POW(x, y) (x.pow(y))

struct EigenPowFunctor {
  // Elementwise pow over n elements: out[i] = a[i]^e when b is null,
  // out[i] = a[i]^b[0] when b is a scalar blob, out[i] = a[i]^b[i] otherwise.
  template <int b_is_scalar, typename T1, typename T2, typename R>
  inline void
  Run(size_t n, const T1* a, const T2* b, T2 e, R* out, CPUContext*) {
    if (b == nullptr) {
      EigenVectorArrayMap<R>(out, n) =
          EIGEN_POW((ConstEigenVectorArrayMap<T1>(a, n)), (e));
    } else {
      if (b_is_scalar) {
        EigenVectorArrayMap<R>(out, n) =
            EIGEN_POW((ConstEigenVectorArrayMap<T1>(a, n)), (b[0]));
      } else {
        EigenVectorArrayMap<R>(out, n) = EIGEN_POW(
            (ConstEigenVectorArrayMap<T1>(a, n)),
            (ConstEigenVectorArrayMap<T2>(b, n)));
      }
    }
  }
  // Broadcast of b (length n) over a viewed as a (pre x n) row-major array:
  // out[i][j] = a[i][j]^b[j].
  template <typename T1, typename T2, typename R>
  void RunWithBroadcast(
      const T1* a,
      const T2* b,
      R* out,
      size_t pre,
      size_t n,
      CPUContext*) {
    EigenArrayMap<R>(out, n, pre) = EIGEN_POW(
        (ConstEigenArrayMap<T1>(a, n, pre)),
        (ConstEigenVectorArrayMap<T2>(b, n)).rowwise().replicate(pre));
    /*
    // The code below only allows elementary ops, such as +, -, * and /,
    // and does not allow operations such as pow, exp and log:
    EIGEN_POW(
        (ConstEigenArrayMap<T>(a, n, pre).colwise()),
        (ConstEigenVectorArrayMap<T>(b, n)));
    */
  }
  // Broadcast of b (length n) over a viewed as a (pre x n x post) array:
  // out[i][j][k] = a[i][j][k]^b[j].
  template <typename T1, typename T2, typename R>
  void RunWithBroadcast2(
      const T1* a,
      const T2* b,
      R* out,
      size_t pre,
      size_t n,
      size_t post,
      CPUContext*) {
    for (size_t i = 0; i < pre; ++i) {
      EigenArrayMap<R>(out + i * n * post, post, n) = EIGEN_POW(
          (ConstEigenArrayMap<T1>(a + i * n * post, post, n)),
          (Eigen::Map<const Eigen::Array<T2, 1, Eigen::Dynamic>>(b, n))
              .colwise()
              .replicate(post));
      /*
      // The code below only allows elementary ops, such as +, -, * and /,
      // and does not allow operations such as pow, exp and log:
      EIGEN_POW(
          (ConstEigenArrayMap<T>(a + i * n * post, post, n).rowwise()),
          (Eigen::Map<const Eigen::Array<T, 1, Eigen::Dynamic>>(b, n)));
      */
    }
  }
};

REGISTER_CPU_OPERATOR(
    Pow,
    PowOp<
        TensorTypes<float> /*NumericTypes*/,
        CPUContext,
        EigenPowFunctor,
        SameTypeAsInput>)

OPERATOR_SCHEMA(Pow)
    .NumInputs(1, 2)
    .NumOutputs(1)
    .AllowInplace({{0, 0}, {1, 0}})
    .IdenticalTypeAndShapeOfInput(0)
    .SetDoc(R"DOC(
The *Pow* op takes an input data tensor $X$ and an exponent parameter *exponent*, which can be a scalar or another tensor. As output, it produces a single output data tensor $Y$, where the function $f(x) = x^{exponent}$ has been applied to $X$ elementwise.

GitHub Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/pow_op.h
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/pow_op.cc


<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Pow",
    ["X", "exponent"],
    ["Y"],
    broadcast=1
)

workspace.FeedBlob("X", np.array([1,2,3,4,5,6]).astype(np.float32))
print("X: ", workspace.FetchBlob("X"))

workspace.FeedBlob("exponent", np.array([2]).astype(np.float32))
print("exponent: ", workspace.FetchBlob("exponent"))

workspace.RunOperatorOnce(op)
print("Y: ", workspace.FetchBlob("Y"))

```

**Result**

```

X: [1. 2. 3. 4. 5. 6.]
exponent: [2.]
Y: [ 1. 4. 9. 16. 25. 36.]

```

</details>
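
The exponent can also be passed as a scalar argument instead of a second input. A minimal sketch of that form, assuming the same `workspace`, `core`, and `np` setup as above (the commented output is the expected elementwise square, not a captured run):

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Pow",
    ["X"],
    ["Y"],
    exponent=2.0
)

workspace.FeedBlob("X", np.array([1,2,3]).astype(np.float32))
workspace.RunOperatorOnce(op)
print("Y: ", workspace.FetchBlob("Y"))  # expected: [1. 4. 9.]

```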


)DOC")
    .Input(0, "X", "Input data blob to be operated on.")
    .Input(
        1,
        "exponent",
        "Exponent blob containing the exponent(s) for calculation. "
        "Do not use if setting exponent via argument.")
    .Output(0, "Y", "Output data blob with the same shape as the input.")
    .Arg(
        "exponent",
        "The exponent of the power function. Do not use if setting exponent "
        "via input.")
    .Arg(
        "axis",
        "*(type: int; default: -1)* Axis at which the exponent blob is "
        "aligned with X when broadcast is enabled; the default aligns "
        "trailing dimensions. See the sketch after this schema.")
    .Arg(
        "broadcast",
        "*(type: bool; default: False)* Pass 1 to broadcast the exponent "
        "blob against X.");
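
A hedged sketch of broadcasting the exponent along an axis, again assuming the Caffe2 Python `workspace`/`core`/`np` setup from the docstring example; the commented output is the expected result, not a captured run:

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Pow",
    ["X", "exponent"],
    ["Y"],
    broadcast=1,
    axis=1  # align exponent (length 3) with dim 1 of X
)

workspace.FeedBlob("X", np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32))
workspace.FeedBlob("exponent", np.array([1, 2, 3]).astype(np.float32))
workspace.RunOperatorOnce(op)
print("Y: ", workspace.FetchBlob("Y"))
# expected:
# [[  1.   4.  27.]
#  [  4.  25. 216.]]

```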

class GetPowGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    ArgumentHelper arg_helper(def_);
    if (arg_helper.HasArgument("exponent")) { // exponent is a scalar argument
      // function f(w,a) = w^a
      // gradient operator with respect to the first input tensor:
      // df/dw = a * w^(a-1) (all operations are component-wise)
      float exponent = arg_helper.GetSingleArgument<float>("exponent", 0.0);
      Argument scale_arg;
      scale_arg.set_name("scale");
      scale_arg.set_f(exponent);
      Argument pow_arg;
      pow_arg.set_name("exponent");
      if (I(0) != O(0)) {
        pow_arg.set_f(exponent - 1);
      } else {
        // In-place case: w has been overwritten by w^a, so w^(a-1) is
        // recovered as (w^a)^((a-1)/a). This requires a != 0.
        LOG(WARNING) << "In-place Pow gradient, possible loss of precision";
        constexpr float kEps = 1e-12f;
        CAFFE_ENFORCE(std::fabs(exponent) > kEps);
        pow_arg.set_f((exponent - 1) / exponent);
      }
      return vector<OperatorDef>{CreateOperatorDef(
                                     "Pow",
                                     "",
                                     std::vector<string>{I(0)},
                                     std::vector<string>{GI(0)},
                                     std::vector<Argument>{pow_arg}),
                                 CreateOperatorDef(
                                     "Mul",
                                     "",
                                     std::vector<string>{GI(0), GO(0)},
                                     std::vector<string>{GI(0)}),
                                 CreateOperatorDef(
                                     "Scale",
                                     "",
                                     std::vector<string>{GI(0)},
                                     std::vector<string>{GI(0)},
                                     std::vector<Argument>{scale_arg})};
      /*
      // Alternative gradient computation: O(0)/I(0) = w^a / w = w^(a-1)
      // avoids recomputing the power.
      return vector<OperatorDef>{CreateOperatorDef(
                                     "Div",
                                     "",
                                     std::vector<string>{O(0), I(0)},
                                     std::vector<string>{GI(0)}),
                                 CreateOperatorDef(
                                     "Mul",
                                     "",
                                     std::vector<string>{GI(0), GO(0)},
                                     std::vector<string>{GI(0)}),
                                 CreateOperatorDef(
                                     "Scale",
                                     "",
                                     std::vector<string>{GI(0)},
                                     std::vector<string>{GI(0)},
                                     std::vector<Argument>{scale_arg})};
      */
    } else { // the exponent is a second input tensor
      CAFFE_ENFORCE(
          Def().input(0) != Def().output(0) &&
              Def().input(1) != Def().output(0),
          "Gradient computation cannot be carried out if Pow uses in-place "
          "computation: ",
          ProtoDebugString(Def()));
      vector<OperatorDef> grad_ops;
      Argument one_arg;
      one_arg.set_name("value");
      one_arg.set_f(1);
      Argument broadcast, axis, axis_str, order;
      bool bflag = ArgumentHelper::HasArgument(Def(), "broadcast");

      if (bflag) {
        // Forward the broadcast arguments to the gradient ops, falling back
        // to their defaults when absent.
        broadcast = GetArgument(Def(), "broadcast");
        if (ArgumentHelper::HasArgument(Def(), "axis")) {
          axis = GetArgument(Def(), "axis");
        } else {
          axis = MakeArgument<int>("axis", -1);
        }
        if (ArgumentHelper::HasArgument(Def(), "axis_str")) {
          axis_str = GetArgument(Def(), "axis_str");
        } else {
          axis_str = MakeArgument<string>("axis_str", "");
        }
        if (ArgumentHelper::HasArgument(Def(), "order")) {
          order = GetArgument(Def(), "order");
        } else {
          order = MakeArgument<string>("order", "NCHW");
        }
      }

      // function f(w,a) = w^a
      // gradient operator with respect to the first input tensor:
      // df/dw = a * w^(a-1) (all operations are component-wise)
      grad_ops.push_back(CreateOperatorDef(
          "ConstantFill",
          "",
          std::vector<string>{I(1)},
          std::vector<string>{GI(1)},
          std::vector<Argument>{one_arg}));
      grad_ops.push_back(CreateOperatorDef(
          "Sub",
          "",
          std::vector<string>{I(1), GI(1)},
          std::vector<string>{GI(1)}));
      if (bflag) {
        grad_ops.push_back(CreateOperatorDef(
            "Pow",
            "",
            std::vector<string>{I(0), GI(1)},
            std::vector<string>{GI(0)},
            vector<Argument>{broadcast, axis, axis_str, order}));
      } else {
        grad_ops.push_back(CreateOperatorDef(
            "Pow",
            "",
            std::vector<string>{I(0), GI(1)},
            std::vector<string>{GI(0)}));
      }

      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(0), GO(0)},
          std::vector<string>{GI(0)}));
      if (bflag) {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(0), I(1)},
            std::vector<string>{GI(0)},
            vector<Argument>{broadcast, axis, axis_str, order}));
      } else {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(0), I(1)},
            std::vector<string>{GI(0)}));
      }
      /*
      // Alternative gradient computation (no broadcast support)
      grad_ops.push_back(CreateOperatorDef(
          "Div",
          "",
          std::vector<string>{O(0), I(0)},
          std::vector<string>{GI(0)}));
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(0), GO(0)},
          std::vector<string>{GI(0)}));
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(0), I(1)},
          std::vector<string>{GI(0)}));
      */
      // gradient operator with respect to the second input tensor:
      // df/da = w^a * ln(w) (all operations are component-wise)
      /*
      // reset GI(1) to zero
      Argument zero_arg;
      zero_arg.set_name("value");
      zero_arg.set_f(0);
      grad_ops.push_back(CreateOperatorDef(
          "ConstantFill",
          "",
          std::vector<string>{I(1)},
          std::vector<string>{GI(1)},
          std::vector<Argument>{zero_arg}));
      */
      grad_ops.push_back(CreateOperatorDef(
          "Log",
          "",
          std::vector<string>{I(0)},
          std::vector<string>{GI(1) + "_autogen_pre_red"}));
      grad_ops.push_back(CreateOperatorDef(
          "Mul",
          "",
          std::vector<string>{GI(1) + "_autogen_pre_red", O(0)},
          std::vector<string>{GI(1) + "_autogen_pre_red"}));
      if (bflag) {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
            std::vector<string>{GI(1) + "_autogen_pre_red"}));
        // With broadcasting, the elementwise product has the shape of X and
        // must be reduced back to the shape of the exponent blob.
        grad_ops.push_back(CreateOperatorDef(
            "SumReduceLike",
            "",
            vector<string>{GI(1) + "_autogen_pre_red", I(1)},
            vector<string>{GI(1)},
            vector<Argument>{axis, axis_str, order}));
      } else {
        grad_ops.push_back(CreateOperatorDef(
            "Mul",
            "",
            std::vector<string>{GI(1) + "_autogen_pre_red", GO(0)},
            std::vector<string>{GI(1)}));
      }

      return grad_ops;
    }
  }

  // Argument `shape` is no longer needed in backprop.
  bool CopyArguments() const override {
    return false;
  }
};

REGISTER_GRADIENT(Pow, GetPowGradient);

} // namespace caffe2
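
The two gradient formulas used above are easy to sanity-check numerically. A minimal NumPy sketch (illustration only, independent of Caffe2), comparing the analytic expressions against central finite differences:

```

import numpy as np

# f(w, a) = w^a, elementwise
w = np.array([0.5, 1.0, 2.0, 3.0])
a = np.array([2.0, 0.5, 3.0, 1.5])
eps = 1e-6

# df/dw = a * w^(a-1)
dw_analytic = a * w ** (a - 1)
dw_numeric = ((w + eps) ** a - (w - eps) ** a) / (2 * eps)
assert np.allclose(dw_analytic, dw_numeric, atol=1e-5)

# df/da = w^a * ln(w)
da_analytic = w ** a * np.log(w)
da_numeric = (w ** (a + eps) - w ** (a - eps)) / (2 * eps)
assert np.allclose(da_analytic, da_numeric, atol=1e-5)

```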