Caffe2 - C++ API
A deep learning, cross-platform ML framework
mean_op.h
1 #ifndef CAFFE2_OPERATORS_MEAN_OPS_H_
2 #define CAFFE2_OPERATORS_MEAN_OPS_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/core/types.h"
9 #include "caffe2/utils/math.h"
10 #include "caffe2/utils/proto_utils.h"
11 
12 namespace caffe2 {
13 
14 template <class Context>
15 class MeanOp final : public Operator<Context> {
16  public:
17  USE_OPERATOR_CONTEXT_FUNCTIONS;
18  USE_SIMPLE_CTOR_DTOR(MeanOp)
19 
20  template <typename T>
21  bool DoRunWithType() {
22  auto& input0 = Input(0);
23 
24  auto* output = Output(0, input0.sizes(), at::dtype<T>());
25  output->CopyFrom(input0, true /*async*/);
26 
27  if (InputSize() == 1) {
28  return true;
29  }
30 
31  // Dimension checking
32  for (int i = 1; i < InputSize(); ++i) {
33  if (output->sizes() != Input(i).sizes()) {
34  CAFFE_THROW(
35  "Check failed: output->sizes() == Input(i).sizes().",
36  "Description: Input #",
37  i,
38  ", input dimension:",
39  Input(i).sizes(),
40  " should match output dimension: ",
41  output->sizes());
42  }
43  }
44 
45  T* output_data = output->template mutable_data<T>();
46  for (int i = 1; i < InputSize(); ++i) {
47  math::Add(
48  output->numel(),
49  output_data,
50  Input(i).template data<T>(),
51  output_data,
52  &context_);
53  }
54 
55  math::Scale(
56  output->numel(),
57  1.0f / InputSize(),
58  output_data,
59  output_data,
60  &context_);
61 
62  return true;
63  }
64 
65  bool RunOnDevice() override {
66  if (Input(0).template IsType<float>()) {
67  return DoRunWithType<float>();
68  } else {
69  CAFFE_THROW(
70  "Mean operator only supports 32-bit float, but",
71  " input was of type ",
72  Input(0).dtype().name());
73  }
74  }
75 };
76 
77 template <class Context>
78 class MeanGradientOp : public Operator<Context> {
79  public:
80  USE_OPERATOR_CONTEXT_FUNCTIONS;
81 
82  template <class... Args>
83  explicit MeanGradientOp(Args&&... args)
84  : Operator<Context>(std::forward<Args>(args)...) {}
85 
86  template <typename T>
87  bool DoRunWithType() {
88  auto& dY = Input(0);
89  const auto* dY_data = dY.template data<T>();
90  int size = dY.numel();
91 
92  int num_inputs = OutputSize();
93  float scale = 1.0f / num_inputs;
94 
95  // dX0 = scale * dY
96 
97  auto* dX0 = Output(0, dY.sizes(), at::dtype<T>());
98  math::Scale(
99  size, scale, dY_data, dX0->template mutable_data<T>(), &context_);
100 
101  // Copy the rest dX
102  for (int i = 1; i < num_inputs; i++) {
103  auto* cur_dX = Output(i);
104  cur_dX->ResizeLike(dY);
105  cur_dX->CopyFrom(*dX0, true /*async*/);
106  }
107 
108  return true;
109  }
110 
111  bool RunOnDevice() override {
112  if (Input(0).template IsType<float>()) {
113  return DoRunWithType<float>();
114  } else {
115  CAFFE_THROW(
116  "Mean operator only supports 32-bit float, but",
117  " input was of type ",
118  Input(0).dtype().name());
119  }
120  }
121 };
122 
123 } // namespace caffe2
124 
125 #endif // CAFFE2_OPERATORS_MEAN_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13