Caffe2 - C++ API
A deep learning, cross platform ML framework
pow_op.h
1 #ifndef CAFFE2_OPERATORS_POW_OP_H_
2 #define CAFFE2_OPERATORS_POW_OP_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/operators/elementwise_ops.h"
9 #include "caffe2/operators/elementwise_ops_utils.h"
10 #include "caffe2/utils/math.h"
11 
12 namespace caffe2 {
13 
14 template <
15  typename InputTypes,
16  class Context,
17  class Functor,
18  class TypeMap = SameTypeAsInput>
19 class PowOp : public Operator<Context> {
20  public:
21  USE_OPERATOR_CONTEXT_FUNCTIONS;
22 
23  template <class... Args>
24  explicit PowOp(Args&&... args)
25  : Operator<Context>(std::forward<Args>(args)...),
26  OP_SINGLE_ARG(bool, "broadcast", enable_broadcast_, 0),
27  OP_SINGLE_ARG(int, "axis", axis_, -1),
28  OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
29  OP_SINGLE_ARG(string, "order", order_, "NCHW"),
30  functor_() {
31  if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
32  exponent_ = this->template GetSingleArgument<float>(
33  "exponent", 0); // based on pow_ops.h
34  } else if (InputSize() == 2) { // BinaryElementwiseOp
35  // Figure out the correct axis to use.
36  if (enable_broadcast_) {
37  if (axis_ != -1) {
38  // Get axis from an explicit axis argument.
39  CAFFE_ENFORCE_EQ(
40  axis_str_.size(),
41  0,
42  "Args axis and axis_str cannot be used simultaneously.");
43  } else if (axis_str_.size()) {
44  // Get the axis index semantically.
45  CAFFE_ENFORCE_EQ(
46  axis_str_.size(), 1, "Unsupported axis string", axis_str_);
47  size_t semantic_axis_ = order_.find(axis_str_);
48  CAFFE_ENFORCE_NE(
49  semantic_axis_,
50  string::npos,
51  "Unrecognizable axis string ",
52  axis_str_,
53  " from order string ",
54  order_);
55  axis_ = semantic_axis_;
56  }
57  } else {
58  CAFFE_ENFORCE(
59  axis_ == -1 && axis_str_.size() == 0,
60  "Do not specify axis or axis_str if broadcast is not enabled.");
61  }
62  } else {
63  CAFFE_THROW(
64  "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
65  }
66  }
67 
68  bool RunOnDevice() override {
69  return DispatchHelper<InputTypes>::call(this, Input(0));
70  }
71 
72  template <typename T>
73  bool DoRunWithType() {
74  if ((InputSize() == 1) && HasArgument("exponent")) { // UnaryElementwiseOp
75  const auto& A = Input(0);
76 
77  auto* C =
78  Output(0, A.sizes(), at::dtype<typename TypeMap::template type<T>>());
79  const T* Adata = A.template data<T>();
80  auto* Cdata =
81  C->template mutable_data<typename TypeMap::template type<T>>();
82  functor_.template Run<true, T, float, T>(
83  A.numel(), Adata, NULL, exponent_, Cdata, &context_);
84  } else if (InputSize() == 2) { // BinaryElementwiseOp
85  const auto& A = Input(0);
86  const auto& B = Input(1);
87  CAFFE_ENFORCE(
88  !IsInputOutputAlias(1, 0) || !enable_broadcast_,
89  "In-place is allowed only with the first tensor when broadcasting");
90  auto* C =
91  Output(0, A.sizes(), at::dtype<typename TypeMap::template type<T>>());
92  const T* Adata = A.template data<T>();
93  const T* Bdata = B.template data<T>();
94  auto* Cdata =
95  C->template mutable_data<typename TypeMap::template type<T>>();
96  if (!enable_broadcast_) {
97  CAFFE_ENFORCE_EQ(
98  A.sizes(),
99  B.sizes(),
100  "Dimension mismatch - did you forget to set broadcast=1?");
101  functor_.template Run<false, T, T, T>(
102  A.numel(), Adata, Bdata, 0, Cdata, &context_);
103  } else if (B.numel() == 1) {
104  functor_.template Run<true, T, T, T>(
105  A.numel(), Adata, Bdata, 0, Cdata, &context_);
106  } else {
107  size_t pre, n, post;
108  std::tie(pre, n, post) =
109  elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
110  if (post == 1) {
111  functor_.template RunWithBroadcast<T, T, T>(
112  Adata, Bdata, Cdata, pre, n, &context_);
113  } else {
114  functor_.template RunWithBroadcast2<T, T, T>(
115  Adata, Bdata, Cdata, pre, n, post, &context_);
116  }
117  }
118  } else {
119  CAFFE_THROW(
120  "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
121  }
122  return true;
123  }
124 
125  private:
126  bool enable_broadcast_;
127  int axis_;
128  string axis_str_;
129  string order_;
130  float exponent_;
131  Functor functor_;
132 };
133 
134 } // namespace caffe2
135 
136 #endif // CAFFE2_OPERATORS_POW_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
Definition: static.cpp:52
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:58
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70