1 #ifndef CAFFE2_OPERATORS_POW_OP_H_ 2 #define CAFFE2_OPERATORS_POW_OP_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/operators/elementwise_ops.h" 9 #include "caffe2/operators/elementwise_ops_utils.h" 10 #include "caffe2/utils/math.h" 18 class TypeMap = SameTypeAsInput>
21 USE_OPERATOR_CONTEXT_FUNCTIONS;
23 template <
class... Args>
24 explicit PowOp(Args&&... args)
26 OP_SINGLE_ARG(
bool,
"broadcast", enable_broadcast_, 0),
27 OP_SINGLE_ARG(
int,
"axis", axis_, -1),
28 OP_SINGLE_ARG(
string,
"axis_str", axis_str_,
""),
29 OP_SINGLE_ARG(
string,
"order", order_,
"NCHW"),
31 if ((InputSize() == 1) &&
HasArgument(
"exponent")) {
32 exponent_ = this->
template GetSingleArgument<float>(
34 }
else if (InputSize() == 2) {
36 if (enable_broadcast_) {
42 "Args axis and axis_str cannot be used simultaneously.");
43 }
else if (axis_str_.size()) {
46 axis_str_.size(), 1,
"Unsupported axis string", axis_str_);
47 size_t semantic_axis_ = order_.find(axis_str_);
51 "Unrecognizable axis string ",
53 " from order string ",
55 axis_ = semantic_axis_;
59 axis_ == -1 && axis_str_.size() == 0,
60 "Do not specify axis or axis_str if broadcast is not enabled.");
64 "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
68 bool RunOnDevice()
override {
73 bool DoRunWithType() {
74 if ((InputSize() == 1) &&
HasArgument(
"exponent")) {
78 Output(0,
A.sizes(), at::dtype<typename TypeMap::template type<T>>());
79 const T* Adata =
A.template data<T>();
81 C->template mutable_data<typename TypeMap::template type<T>>();
82 functor_.template Run<true, T, float, T>(
83 A.numel(), Adata, NULL, exponent_, Cdata, &context_);
84 }
else if (InputSize() == 2) {
88 !IsInputOutputAlias(1, 0) || !enable_broadcast_,
89 "In-place is allowed only with the first tensor when broadcasting");
91 Output(0,
A.sizes(), at::dtype<typename TypeMap::template type<T>>());
92 const T* Adata =
A.template data<T>();
93 const T* Bdata =
B.template data<T>();
95 C->template mutable_data<typename TypeMap::template type<T>>();
96 if (!enable_broadcast_) {
100 "Dimension mismatch - did you forget to set broadcast=1?");
101 functor_.template Run<false, T, T, T>(
102 A.numel(), Adata, Bdata, 0, Cdata, &context_);
103 }
else if (
B.numel() == 1) {
104 functor_.template Run<true, T, T, T>(
105 A.numel(), Adata, Bdata, 0, Cdata, &context_);
108 std::tie(pre, n, post) =
109 elementwise_ops_utils::ComputeLegacyBroadcastSizes(
A,
B, axis_);
111 functor_.template RunWithBroadcast<T, T, T>(
112 Adata, Bdata, Cdata, pre, n, &context_);
114 functor_.template RunWithBroadcast2<T, T, T>(
115 Adata, Bdata, Cdata, pre, n, post, &context_);
120 "Only a tensor with an argument or two input tensors are supported as input to pow operator.");
126 bool enable_broadcast_;
136 #endif // CAFFE2_OPERATORS_POW_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.