#ifndef CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_
#define CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_

#include <iterator>
#include <string>
#include <tuple>
#include <vector>

#include "caffe2/core/common_omp.h"
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
#include "caffe2/operators/elementwise_ops_utils.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

using NumericTypes = TensorTypes<int32_t, int64_t, float, double>;
using IntTypes = TensorTypes<int32_t, int64_t>;
using BoolTypes = TensorTypes<bool>;
using IntBoolTypes = TensorTypes<int32_t, int64_t, bool>;
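
// These aliases are the dtype lists used for dispatch below: RunOnDevice()
// hands a tensor to DispatchHelper<InputTypes>::call(...), which invokes the
// DoRunWithType<T>() instantiation matching the tensor's element type. An
// operator instantiated with NumericTypes therefore accepts exactly
// int32/int64/float/double inputs and rejects everything else at run time.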
struct SameTypeAsInput {
  template <typename T>
  using type = T;
};

template <typename R>
struct FixedType {
  template <typename T>
  using type = R;
};

template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
class UnaryElementwiseWithArgsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit UnaryElementwiseWithArgsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...), functor_(*this) {}

  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }
  template <typename T>
  bool DoRunWithType() {
    const auto& X = Input(0);

    auto* Y = Output(
        0, X.sizes(), at::dtype<typename OutputTypeMap::template type<T>>());
    return functor_(
        X.numel(),
        X.template data<T>(),
        Y->template mutable_data<typename OutputTypeMap::template type<T>>(),
        &context_);
  }

 private:
  Functor functor_;
};
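
// Note on OutputTypeMap: the default SameTypeAsInput maps T -> T, while
// FixedType<R> pins the output dtype to R regardless of the input type. A
// sketch of a bool-valued unary op type built this way (IsPositiveFunctor is
// hypothetical, not part of this header):
//
//   using IsPositiveOp = UnaryElementwiseOp<
//       NumericTypes,
//       CPUContext,
//       IsPositiveFunctor<CPUContext>,
//       FixedType<bool>>;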
// UnaryFunctorWithDefaultCtor adapts a default-constructible functor to the
// OperatorBase& constructor interface expected by UnaryElementwiseWithArgsOp,
// simply forwarding operator() to the wrapped functor.
template <class Functor>
struct UnaryFunctorWithDefaultCtor {
  explicit UnaryFunctorWithDefaultCtor(OperatorBase& /* op */) {}

  template <typename TIn, typename TOut, class Context>
  bool operator()(const int size, const TIn* X, TOut* Y, Context* context)
      const {
    return functor(size, X, Y, context);
  }

  Functor functor{};
};
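
// A minimal sketch of a stateless functor this wrapper can adapt, and the
// resulting operator type via the UnaryElementwiseOp alias defined below
// (CubeFunctor is hypothetical; math::Cub is assumed to be available from
// caffe2/utils/math.h):
//
//   template <class Context>
//   struct CubeFunctor {
//     template <typename T>
//     bool operator()(const int n, const T* x, T* y, Context* context) const {
//       math::Cub<T, Context>(n, x, y, context);
//       return true;
//     }
//   };
//
//   using CubeOp =
//       UnaryElementwiseOp<NumericTypes, CPUContext, CubeFunctor<CPUContext>>;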
// UnaryElementwiseOp is a wrapper around UnaryElementwiseWithArgsOp, with the
// difference that it takes a functor with a default constructor, i.e. one
// that does not need any arguments at operator-creation time.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
using UnaryElementwiseOp = UnaryElementwiseWithArgsOp<
    InputTypes,
    Context,
    UnaryFunctorWithDefaultCtor<Functor>,
    OutputTypeMap>;

template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BinaryElementwiseWithArgsOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, string("")),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (axis_str_.size()) {
        // Get the axis index semantically.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1, "Unsupported axis string", axis_str_);
        const size_t semantic_axis_ = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis_,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = semantic_axis_;
      }
    } else {
      CAFFE_ENFORCE(
          axis_ == -1 && axis_str_.empty(),
          "Do not specify axis or axis_str if broadcast is not enabled.");
    }
  }
  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(0));
  }
  template <typename T>
  bool DoRunWithType() {
    const auto& A = Input(0);
    const auto& B = Input(1);

    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();
    std::vector<int> A_dims;
    std::vector<int> B_dims;
    std::vector<int64_t> C_dims;
    if (legacy_broadcast_) {
      CAFFE_ENFORCE(
          !IsInputOutputAlias(1, 0),
          "In-place is allowed only with the first tensor when "
          "legacy-broadcasting");
      C_dims = A.sizes().vec();
      if (B.numel() == 1) {
        A_dims = {static_cast<int>(A.numel())};
        B_dims = {1};
      } else {
        size_t pre, n, post;
        std::tie(pre, n, post) =
            elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
        A_dims = {
            static_cast<int>(pre), static_cast<int>(n),
            static_cast<int>(post)};
        B_dims = {static_cast<int>(n), 1};
      }
    } else {
      std::copy(
          A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
      std::copy(
          B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
      const auto C_dims_int =
          elementwise_ops_utils::ComputeBinaryBroadcastForwardDims(
              A_dims, B_dims);
      std::copy(
          C_dims_int.cbegin(), C_dims_int.cend(), std::back_inserter(C_dims));
      if (IsInputOutputAlias(0, 0)) {
        CAFFE_ENFORCE_EQ(C_dims_int, A_dims);
      } else if (IsInputOutputAlias(1, 0)) {
        CAFFE_ENFORCE_EQ(C_dims_int, B_dims);
      }
    }

    auto* C = Output(
        0, C_dims, at::dtype<typename OutputTypeMap::template type<T>>());
    auto* C_data =
        C->template mutable_data<typename OutputTypeMap::template type<T>>();
    return functor_.Forward(A_dims, B_dims, A_data, B_data, C_data, &context_);
  }
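
// Worked example of the two broadcast modes above, with A of shape
// (2, 3, 4, 5):
//   - legacy ("broadcast": 1): B of shape (3, 4) with axis 1 gives
//     (pre, n, post) = (2, 12, 5), so A is viewed as (2, 12, 5) and B as
//     (12, 1);
//   - standard (numpy-style) broadcasting: B of shape (4, 5) or (3, 1, 5)
//     aligns against A's trailing dimensions and C comes out as (2, 3, 4, 5).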
 private:
  const bool legacy_broadcast_;
  int axis_;
  const std::string axis_str_;
  const std::string order_;

  Functor functor_;
};
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput,
    class GradientTypeMap = SameTypeAsInput>
class BinaryElementwiseWithArgsGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit BinaryElementwiseWithArgsGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(bool, "broadcast", legacy_broadcast_, false),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
        OP_SINGLE_ARG(string, "order", order_, "NCHW"),
        functor_(*this) {
    if (legacy_broadcast_) {
      if (axis_ != -1) {
        // Get axis from an explicit axis argument.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(),
            0,
            "Args axis and axis_str cannot be used simultaneously.");
      } else if (axis_str_.size()) {
        // Get the axis index semantically.
        CAFFE_ENFORCE_EQ(
            axis_str_.size(), 1, "Unsupported axis string", axis_str_);
        const size_t semantic_axis_ = order_.find(axis_str_);
        CAFFE_ENFORCE_NE(
            semantic_axis_,
            string::npos,
            "Unrecognizable axis string ",
            axis_str_,
            " from order string ",
            order_);
        axis_ = semantic_axis_;
      }
    } else {
      CAFFE_ENFORCE(
          axis_ == -1 && axis_str_.empty(),
          "Do not specify axis or axis_str if broadcast is not enabled.");
    }
  }
  bool RunOnDevice() override {
    return DispatchHelper<InputTypes>::call(this, Input(1));
  }
  template <typename T>
  bool DoRunWithType() {
    const auto& dC = Input(0);
    const auto& A = Input(1);
    const auto& B = Input(2);

    std::vector<int> A_dims;
    std::vector<int> B_dims;
    if (legacy_broadcast_) {
      if (B.numel() == 1) {
        A_dims = {static_cast<int>(A.numel())};
        B_dims = {1};
      } else {
        size_t pre, n, post;
        std::tie(pre, n, post) =
            elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
        A_dims = {
            static_cast<int>(pre), static_cast<int>(n),
            static_cast<int>(post)};
        B_dims = {static_cast<int>(n), 1};
      }
    } else {
      std::copy(
          A.sizes().cbegin(), A.sizes().cend(), std::back_inserter(A_dims));
      std::copy(
          B.sizes().cbegin(), B.sizes().cend(), std::back_inserter(B_dims));
    }
    const typename OutputTypeMap::template type<T>* C_data = nullptr;
    if (InputSize() == 4) {
      const auto& C = Input(3);
      C_data = C.template data<typename OutputTypeMap::template type<T>>();
    }
    const auto* dC_data =
        dC.template data<typename GradientTypeMap::template type<T>>();
    const T* A_data = A.template data<T>();
    const T* B_data = B.template data<T>();
    auto* dA = Output(
        0, A.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
    auto* dB = Output(
        1, B.sizes(), at::dtype<typename GradientTypeMap::template type<T>>());
    auto* dA_data =
        dA->template mutable_data<typename GradientTypeMap::template type<T>>();
    auto* dB_data =
        dB->template mutable_data<typename GradientTypeMap::template type<T>>();
    return functor_.Backward(
        A_dims,
        B_dims,
        dC_data,
        A_data,
        B_data,
        C_data,
        dA_data,
        dB_data,
        &context_);
  }
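
// Input layout for the gradient op above: Input(0) = dC, Input(1) = A,
// Input(2) = B, and optionally Input(3) = C (the forward output) for functors
// whose backward pass reuses it. dA and dB are allocated with A's and B's
// shapes, so the functor's Backward is responsible for reducing over any
// broadcast dimensions.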
 private:
  const bool legacy_broadcast_;
  int axis_;
  const std::string axis_str_;
  const std::string order_;

  Functor functor_;
};
// BinaryFunctorWithDefaultCtor plays the same role as
// UnaryFunctorWithDefaultCtor: it adapts a default-constructible functor to
// the OperatorBase& constructor interface expected by the ops above.
template <class Functor>
struct BinaryFunctorWithDefaultCtor {
  explicit BinaryFunctorWithDefaultCtor(OperatorBase& /* op */) {}

  template <typename TIn, typename TOut, class Context>
  bool Forward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TIn* A_data,
      const TIn* B_data,
      TOut* C_data,
      Context* context) const {
    return functor.Forward(A_dims, B_dims, A_data, B_data, C_data, context);
  }
  template <typename TGrad, typename TIn, typename TOut, class Context>
  bool Backward(
      const std::vector<int>& A_dims,
      const std::vector<int>& B_dims,
      const TGrad* dC_data,
      const TIn* A_data,
      const TIn* B_data,
      const TOut* C_data,
      TGrad* dA_data,
      TGrad* dB_data,
      Context* context) const {
    return functor.Backward(
        A_dims,
        B_dims,
        dC_data,
        A_data,
        B_data,
        C_data,
        dA_data,
        dB_data,
        context);
  }

  Functor functor{};
};
// BinaryElementwiseOp and BinaryElementwiseGradientOp are wrappers that take
// a functor with a default constructor, analogous to UnaryElementwiseOp.
template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput>
using BinaryElementwiseOp = BinaryElementwiseWithArgsOp<
    InputTypes,
    Context,
    BinaryFunctorWithDefaultCtor<Functor>,
    OutputTypeMap>;

template <
    typename InputTypes,
    class Context,
    class Functor,
    class OutputTypeMap = SameTypeAsInput,
    class GradientTypeMap = SameTypeAsInput>
using BinaryElementwiseGradientOp = BinaryElementwiseWithArgsGradientOp<
    InputTypes,
    Context,
    BinaryFunctorWithDefaultCtor<Functor>,
    OutputTypeMap,
    GradientTypeMap>;
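
// A sketch of wiring a binary functor through these aliases (the operator
// name "MyAdd" and AddFunctor are illustrative; caffe2's real Add functor
// lives in elementwise_add_op.h):
//
//   REGISTER_CPU_OPERATOR(
//       MyAdd,
//       BinaryElementwiseOp<NumericTypes, CPUContext, AddFunctor<CPUContext>>);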
template <class Context>
struct NotFunctor {
  bool operator()(const int N, const bool* X, bool* Y, Context* context)
      const {
    math::Not(N, X, Y, context);
    return true;
  }
};
template <class Context>
struct SignFunctor {
  template <typename T>
  bool operator()(const int N, const T* X, T* Y, Context* context) const {
    math::Sign(N, X, Y, context);
    return true;
  }
};
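
// A sketch of turning these unary functors into operators (operator names
// are illustrative):
//
//   REGISTER_CPU_OPERATOR(
//       MyNot,
//       UnaryElementwiseOp<BoolTypes, CPUContext, NotFunctor<CPUContext>>);
//   REGISTER_CPU_OPERATOR(
//       MySign,
//       UnaryElementwiseOp<NumericTypes, CPUContext, SignFunctor<CPUContext>>);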
#define C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(FunctorName) \
  template <class Context>                                   \
  struct FunctorName##Functor {                              \
    template <typename TIn, typename TOut>                   \
    bool Forward(                                            \
        const std::vector<int>& A_dims,                      \
        const std::vector<int>& B_dims,                      \
        const TIn* A,                                        \
        const TIn* B,                                        \
        TOut* C,                                             \
        Context* context) const {                            \
      math::FunctorName(                                     \
          A_dims.size(),                                     \
          A_dims.data(),                                     \
          B_dims.size(),                                     \
          B_dims.data(),                                     \
          A,                                                 \
          B,                                                 \
          C,                                                 \
          context);                                          \
      return true;                                           \
    }                                                        \
  };

C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(EQ);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(NE);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(LT);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(LE);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(GT);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(GE);

C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(And);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(Or);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(Xor);

C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(BitwiseAnd);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(BitwiseOr);
C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR(BitwiseXor);
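
// Comparison and logic functors are typically instantiated with the output
// dtype pinned to bool via FixedType (a sketch; the operator name is
// illustrative):
//
//   REGISTER_CPU_OPERATOR(
//       MyEQ,
//       BinaryElementwiseOp<
//           IntBoolTypes,
//           CPUContext,
//           EQFunctor<CPUContext>,
//           FixedType<bool>>);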
#undef C10_DECLARE_FOWARD_ONLY_BINARY_FUNCTOR

namespace SRLHelper {
template <typename T>
void sum2one(const T* a, T* y, size_t n);
template <typename T>
void RunWithBroadcastFront(const T* a, T* y, size_t pre, size_t n, CPUContext*);
template <typename T>
void RunWithBroadcastBack(const T* a, T* y, size_t post, size_t n, CPUContext*);
template <typename T>
void RunWithBroadcast2(
    const T* a,
    T* y,
    size_t pre,
    size_t n,
    size_t post,
    CPUContext*);

} // namespace SRLHelper
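
// SumReduceLike sum-reduces input A down to the shape of input B; it is the
// reduction counterpart used when computing gradients of legacy-broadcast
// binary ops, with the SRLHelper routines above handling the front, back,
// and middle broadcast layouts on CPU.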
template <class Context>
class SumReduceLikeOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit SumReduceLikeOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "axis", axis_, -1),
        OP_SINGLE_ARG(string, "axis_str", axis_str_, ""),
        OP_SINGLE_ARG(string, "order", order_, "NCHW") {
499 "Args axis and axis_str cannot be used simultaneously.");
500 }
else if (axis_str_.size()) {
503 axis_str_.size(), 1,
"Unsupported axis string", axis_str_);
504 size_t semantic_axis = order_.find(axis_str_);
508 "Unrecognizable axis string ",
510 " from order string ",
512 axis_ = semantic_axis;
  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<float, double>>::call(this, Input(0));
  }
  template <typename T>
  bool DoRunWithType();

 private:
  int axis_;
  std::string axis_str_;
  std::string order_;
  Tensor ones_{Context::GetDeviceType()};
  Tensor sum_buffer_{Context::GetDeviceType()};
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_ELEMENTWISE_OPS_H_