1 #include "caffe2/operators/elementwise_ops.h" 2 #include "caffe2/utils/eigen_utils.h" 10 UnaryElementwiseOp<BoolTypes, CPUContext, NotFunctor<CPUContext>>);
// Unary elementwise CPU operators: the registration ending above uses
// NotFunctor over BoolTypes; the one below uses SignFunctor over NumericTypes.
// NOTE(review): the registered operator names are on lines not shown here.
11 REGISTER_CPU_OPERATOR(
13 UnaryElementwiseOp<NumericTypes, CPUContext, SignFunctor<CPUContext>>);
// Comparison operators (EQ, NE, LT, LE, GT, GE): each Op is registered as a
// BinaryElementwiseOp over TensorTypes<bool, int32_t, int64_t, float, double>
// using Op##Functor<CPUContext>. Some macro arguments (presumably the
// operator name and an output-type policy) are on lines not shown — confirm.
15 #define REGISTER_CPU_COMPARE_OPERATOR(Op) \ 16 REGISTER_CPU_OPERATOR( \ 18 BinaryElementwiseOp< \ 19 TensorTypes<bool, int32_t, int64_t, float, double>, \ 21 Op##Functor<CPUContext>, \ 24 REGISTER_CPU_COMPARE_OPERATOR(EQ);
25 REGISTER_CPU_COMPARE_OPERATOR(NE);
26 REGISTER_CPU_COMPARE_OPERATOR(LT);
27 REGISTER_CPU_COMPARE_OPERATOR(LE);
28 REGISTER_CPU_COMPARE_OPERATOR(GT);
29 REGISTER_CPU_COMPARE_OPERATOR(GE);
// Logical binary operators (And, Or, Xor): each Op is registered under its own
// name as a BinaryElementwiseOp over BoolTypes using Op##Functor<CPUContext>.
31 #undef REGISTER_CPU_COMPARE_OPERATOR 33 #define REGISTER_CPU_LOGICAL_BINARY_OPERATOR(Op) \ 34 REGISTER_CPU_OPERATOR( \ 35 Op, BinaryElementwiseOp<BoolTypes, CPUContext, Op##Functor<CPUContext>>) 37 REGISTER_CPU_LOGICAL_BINARY_OPERATOR(And);
38 REGISTER_CPU_LOGICAL_BINARY_OPERATOR(Or);
39 REGISTER_CPU_LOGICAL_BINARY_OPERATOR(Xor);
// Bitwise binary operators (BitwiseAnd, BitwiseOr, BitwiseXor): each Op is
// registered as a BinaryElementwiseOp over IntBoolTypes using
// Op##Functor<CPUContext>.
41 #undef REGISTER_CPU_LOGICAL_BINARY_OPERATOR 43 #define REGISTER_CPU_BITWISE_BINARY_OPERATOR(Op) \ 44 REGISTER_CPU_OPERATOR( \ 46 BinaryElementwiseOp<IntBoolTypes, CPUContext, Op##Functor<CPUContext>>) 48 REGISTER_CPU_BITWISE_BINARY_OPERATOR(BitwiseAnd);
49 REGISTER_CPU_BITWISE_BINARY_OPERATOR(BitwiseOr);
50 REGISTER_CPU_BITWISE_BINARY_OPERATOR(BitwiseXor);
52 #undef REGISTER_CPU_BITWISE_BINARY_OPERATOR 55 void SRLHelper::sum2one(
const T* x,
T* y,
size_t n) {
// Full reduction: writes the sum of the n elements of x into the single
// output element *y (x is viewed as an n x 1 Eigen array).
56 *y = ConstEigenArrayMap<T>(x, n, 1).sum();
// Reduces away the leading `pre` dimension. x is viewed as an n x pre array,
// and the sum of each of its n rows is written to y (length n).
// Assuming EigenArrayMap/ConstEigenArrayMap are column-major maps (confirm in
// eigen_utils.h), this gives y[i] = sum over j in [0, pre) of x[j * n + i].
60 void SRLHelper::RunWithBroadcastFront(
66 EigenArrayMap<T>(y, n, 1) = ConstEigenArrayMap<T>(x, n, pre).rowwise().sum();
// Reduces away the trailing `post` dimension. x is viewed as a post x n array,
// and the sum of each of its n columns is written to y (viewed as 1 x n).
// Assuming column-major maps (confirm in eigen_utils.h), this gives
// y[i] = sum over k in [0, post) of x[i * post + k].
70 void SRLHelper::RunWithBroadcastBack(
76 EigenArrayMap<T>(y, 1, n) = ConstEigenArrayMap<T>(x, post, n).colwise().sum();
// General case: reduces away both the leading `pre` and trailing `post`
// dimensions of a (pre, n, post)-laid-out buffer `a`, accumulating
// a[(j * n + i) * post + k] over all j in [0, pre) and k in [0, post) into y[i].
// NOTE(review): y[i] is updated with +=; whether y[i] is zeroed first is on a
// line not shown here (original line 88) — confirm.
// NOTE(review): the loop indices are int; if n/pre/post are unsigned (parameter
// types are not visible here), these are signed/unsigned comparisons.
80 void SRLHelper::RunWithBroadcast2(
87 for (
int i = 0; i < n; ++i) {
89 for (
int j = 0; j < pre; ++j) {
90 for (
int k = 0; k < post; ++k) {
91 y[i] += a[(j * n + i) * post + k];
// Sums input A down to the shape of input B and writes the result to output C,
// which is shaped like B with element type T.
99 bool SumReduceLikeOp<CPUContext>::DoRunWithType() {
100 const auto&
A = Input(0);
101 const auto&
B = Input(1);
// Output 0 must not alias input 1 (B).
103 CAFFE_ENFORCE(!IsInputOutputAlias(1, 0),
"In-place is not allowed.");
// C takes B's sizes and element type T.
104 auto*
C = Output(0,
B.sizes(), at::dtype<T>());
105 const T* Adata =
A.template data<T>();
106 auto* Cdata =
C->template mutable_data<T>();
// B has a single element: reduce all of A to one value.
107 if (
B.numel() == 1) {
108 auto count =
A.numel();
109 SRLHelper::sum2one<T>(Adata, Cdata, count);
// Otherwise split A's shape into (pre, n, post) relative to B and axis_
// (legacy broadcast semantics), then pick the cheapest reduction.
112 std::tie(pre, n, post) =
113 elementwise_ops_utils::ComputeLegacyBroadcastSizes(
A,
B, axis_);
// The condition guarding this call is on a line not shown (original
// line 114); presumably post == 1, so only the leading dimension is
// reduced — confirm.
115 SRLHelper::RunWithBroadcastFront<T>(Adata, Cdata, pre, n, &context_);
116 }
else if (pre == 1) {
// Only the trailing `post` dimension needs reducing.
117 SRLHelper::RunWithBroadcastBack<T>(Adata, Cdata, post, n, &context_);
// General case (its `else` is on a line not shown): reduce over both
// the pre and post dimensions.
119 SRLHelper::RunWithBroadcast2<T>(Adata, Cdata, pre, n, post, &context_);
// Registers SumReduceLikeOp<CPUContext> as the CPU implementation of SumReduceLike.
125 REGISTER_CPU_OPERATOR(SumReduceLike, SumReduceLikeOp<CPUContext>);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...