// NOTE(review): this is an extracted listing of reduction_ops.h; several
// original source lines (class heads, constructor signatures, some call
// arguments) are missing from this view.
// Operator templated on element type T and Context that sums every element of
// its input into a single scalar, optionally dividing by the element count.
// NOTE(review): the class name is on a missing line (upstream Caffe2 places
// SumElementsOp here) — confirm.
1 #ifndef CAFFE2_OPERATORS_REDUCTION_OPS_H_ 2 #define CAFFE2_OPERATORS_REDUCTION_OPS_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 12 template <
typename T,
class Context>
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor variant (signature on a missing line): `average_` is read from
// the operator argument "average", with `false` passed as the default.
19 average_(this->
template GetSingleArgument<bool>(
"average",
false)) {}
// Another constructor variant (signature on a missing line) with the same
// initialization of `average_` from the "average" argument.
24 average_(this->
template GetSingleArgument<bool>(
"average",
false)) {}
// Constructor variant that receives `average` directly instead of reading it
// from the operator arguments.
26 :
Operator<Context>(schema, std::move(inputs), std::move(outputs)), average_(average) {}
// Output 0 is allocated with an empty dims vector (a 0-dimensional scalar
// tensor) of dtype T, and receives the sum of all input elements.
// NOTE(review): `X` is declared on a missing line — presumably Input(0);
// confirm.
29 bool RunOnDevice()
override {
32 auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
34 T* data = sum->template mutable_data<T>();
36 math::Sum<T, Context>(
37 X.numel(), X.template data<T>(), data, &context_, &scratch_);
// When averaging, multiply the sum by 1/numel. The numel() > 0 guard skips
// empty inputs, so the factor never divides by a zero element count.
// NOTE(review): the output argument of Scale is on a missing line —
// presumably the sum is scaled in place; confirm.
38 if (average_ && X.numel() > 0) {
39 math::Scale<float, T, Context>(
41 static_cast<T>(1.) / X.numel(),
42 sum->template data<T>(),
// Temporary buffer handed to math::Sum; created for Context's device type.
51 Tensor scratch_{Context::GetDeviceType()};
// SumElementsIntOp<T, Context>: sums every element of its input into a single
// scalar output of dtype T. No "average" option appears in the visible code.
54 template <
typename T,
class Context>
57 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Variadic constructor template; its signature tail and initializer list are
// on lines missing from this view.
59 template <
class... Args>
62 ~SumElementsIntOp() {}
// Output 0 is allocated with an empty dims vector (a 0-dimensional scalar
// tensor) of dtype T and filled by math::Sum over all input elements.
// NOTE(review): `X` is declared on a missing line — presumably Input(0);
// confirm.
64 bool RunOnDevice()
override {
67 auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
68 T* data = sum->template mutable_data<T>();
69 math::Sum<T, Context>(
70 X.numel(), X.template data<T>(), data, &context_, &scratch_);
// Temporary buffer handed to math::Sum; created for Context's device type.
75 Tensor scratch_{Context::GetDeviceType()};
// Operator templated on element type T and Context that carries an
// `average_` flag. Its RunOnDevice() is only declared here and is defined
// out of line.
// NOTE(review): the class name is on a missing line (upstream Caffe2 places
// SumElementsGradientOp here) — confirm.
78 template <
typename T,
class Context>
81 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor variant (signature on a missing line): `average_` is read from
// the operator argument "average", with `false` passed as the default.
85 average_(this->
template GetSingleArgument<bool>(
"average",
false)) {}
// Another constructor variant (signature on a missing line) with the same
// initialization of `average_` from the "average" argument.
90 average_(this->
template GetSingleArgument<bool>(
"average",
false)) {}
// Constructor variant that receives `average` directly instead of reading it
// from the operator arguments.
92 :
Operator<Context>(schema, std::move(inputs), std::move(outputs)), average_(average) {}
// Declared only; the definition lives outside this header view.
95 bool RunOnDevice()
override;
// Operator templated only on Context. DoRunWithType<T> reduces the whole
// input with math::SumSqr (presumably the sum of squares) into a scalar
// output of dtype T, optionally averaged over the element count.
// NOTE(review): the class name is on a missing line (upstream Caffe2 places
// SumSqrElementsOp here) — confirm.
101 template <
class Context>
105 USE_OPERATOR_CONTEXT_FUNCTIONS;
// NOTE(review): the body is on missing lines — presumably it dispatches to
// DoRunWithType<T> based on the input's dtype; confirm.
107 bool RunOnDevice()
override {
// The "average" argument is re-read on every call, with `false` passed as
// the default.
111 template <
typename T>
112 bool DoRunWithType() {
113 bool average = this->
template GetSingleArgument<bool>(
"average",
false);
// Output 0 is allocated with an empty dims vector (a 0-dimensional scalar
// tensor) of dtype T.
// NOTE(review): `X` and the element-count argument of SumSqr are on missing
// lines — presumably X is Input(0); confirm.
116 auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
117 math::SumSqr<T, Context>(
119 X.template data<T>(),
120 sum->template mutable_data<T>(),
// When averaging a non-empty input, scale the result in place by 1/numel.
// The source and destination of Scale are both the output buffer, and the
// numel() > 0 guard avoids dividing by a zero element count.
123 if (average && X.numel() > 0) {
124 math::Scale<float, T, Context>(
126 float(1.) / X.numel(),
127 sum->template data<T>(),
128 sum->template mutable_data<T>(),
// Tensor member created for Context's device type; presumably the scratch
// argument to SumSqr on a missing line — confirm.
135 Tensor scratch_{Context::GetDeviceType()};
// Operator templated on T, Context and a compile-time ROWWISE flag. It takes a
// rank-3 input of shape (batch_size, M, N) and produces, per batch entry,
// either one max per row (M values, ROWWISE) or one max per column (N values).
// NOTE(review): the class name is on a missing line (upstream Caffe2 places
// MaxReductionOp here) — confirm.
138 template <
typename T,
class Context,
bool ROWWISE>
142 USE_OPERATOR_CONTEXT_FUNCTIONS;
// NOTE(review): `X` is declared on a missing line — presumably Input(0);
// confirm.
144 bool RunOnDevice()
override {
// Only rank-3 inputs are accepted.
146 CAFFE_ENFORCE_EQ(X.dim(), 3);
148 const int batch_size = X.dim32(0);
149 const int M = X.dim32(1);
150 const int N = X.dim32(2);
// Output shape is (batch_size, M) for row-wise and (batch_size, N) for
// column-wise reduction.
152 auto* Y = Output(0, {batch_size, ROWWISE ? M : N}, at::dtype<T>());
// Row-wise path: a single RowwiseMax call over the whole input.
// NOTE(review): the branch condition and the size arguments are on missing
// lines — presumably `if (ROWWISE)`; confirm.
155 math::RowwiseMax<T, Context>(
158 X.template data<T>(),
159 Y->template mutable_data<T>(),
// Column-wise path: each batch entry is a contiguous M*N slice. ColwiseMax
// reads slice i starting at offset i * M * N and writes its N results to
// row i of Y at offset i * N.
// NOTE(review): these offsets are computed in `int`; i * input_size can
// overflow for very large inputs — consider a 64-bit index type.
162 const int input_size = N * M;
163 for (
int i = 0; i < batch_size; ++i) {
164 math::ColwiseMax<T, Context>(
167 X.template data<T>() + i * input_size,
168 Y->template mutable_data<T>() + i * N,
// Operator templated on T, Context and a compile-time ROWWISE flag, whose
// RunOnDevice() is only declared here and defined out of line.
// NOTE(review): the class name is on a missing line (upstream Caffe2 places
// MaxReductionGradientOp here) — confirm.
176 template <
typename T,
class Context,
bool ROWWISE>
180 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Declared only; the definition lives outside this header view.
182 bool RunOnDevice()
override;
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...