1 #ifndef CAFFE2_OPERATORS_MEAN_OPS_H_ 2 #define CAFFE2_OPERATORS_MEAN_OPS_H_ 4 #include "caffe2/core/common_omp.h" 5 #include "caffe2/core/context.h" 6 #include "caffe2/core/logging.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/core/types.h" 9 #include "caffe2/utils/math.h" 10 #include "caffe2/utils/proto_utils.h" 14 template <
class Context>
17 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 USE_SIMPLE_CTOR_DTOR(
MeanOp)
21 bool DoRunWithType() {
22 auto& input0 =
Input(0);
24 auto* output = Output(0, input0.sizes(), at::dtype<T>());
25 output->CopyFrom(input0,
true );
27 if (InputSize() == 1) {
32 for (
int i = 1; i < InputSize(); ++i) {
33 if (output->sizes() !=
Input(i).sizes()) {
35 "Check failed: output->sizes() == Input(i).sizes().",
36 "Description: Input #",
40 " should match output dimension: ",
45 T* output_data = output->template mutable_data<T>();
46 for (
int i = 1; i < InputSize(); ++i) {
50 Input(i).template data<T>(),
65 bool RunOnDevice()
override {
66 if (
Input(0).
template IsType<float>()) {
67 return DoRunWithType<float>();
70 "Mean operator only supports 32-bit float, but",
71 " input was of type ",
72 Input(0).dtype().name());
77 template <
class Context>
80 USE_OPERATOR_CONTEXT_FUNCTIONS;
82 template <
class... Args>
87 bool DoRunWithType() {
89 const auto* dY_data = dY.template data<T>();
90 int size = dY.numel();
92 int num_inputs = OutputSize();
93 float scale = 1.0f / num_inputs;
97 auto* dX0 = Output(0, dY.sizes(), at::dtype<T>());
99 size, scale, dY_data, dX0->template mutable_data<T>(), &context_);
102 for (
int i = 1; i < num_inputs; i++) {
103 auto* cur_dX = Output(i);
104 cur_dX->ResizeLike(dY);
105 cur_dX->CopyFrom(*dX0,
true );
111 bool RunOnDevice()
override {
112 if (
Input(0).
template IsType<float>()) {
113 return DoRunWithType<float>();
116 "Mean operator only supports 32-bit float, but",
117 " input was of type ",
118 Input(0).dtype().name());
125 #endif // CAFFE2_OPERATORS_MEAN_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...