Caffe2 - C++ API
A deep learning, cross platform ML framework
reduction_ops.h
1 #ifndef CAFFE2_OPERATORS_REDUCTION_OPS_H_
2 #define CAFFE2_OPERATORS_REDUCTION_OPS_H_
3 
4 #include "caffe2/core/common_omp.h"
5 #include "caffe2/core/context.h"
6 #include "caffe2/core/logging.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <typename T, class Context>
13 class SumElementsOp : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  explicit SumElementsOp(const OperatorDef& operator_def, Workspace* ws)
18  : Operator<Context>(operator_def, ws),
19  average_(this->template GetSingleArgument<bool>("average", false)) {}
20  explicit SumElementsOp(const OperatorDef& operator_def, Workspace* ws, bool average)
21  : Operator<Context>(operator_def, ws), average_(average) {}
22  explicit SumElementsOp(const c10::FunctionSchema& schema, std::vector<c10::IValue> inputs, std::vector<c10::IValue*> outputs)
23  : Operator<Context>(schema, std::move(inputs), std::move(outputs)),
24  average_(this->template GetSingleArgument<bool>("average", false)) {}
25  explicit SumElementsOp(const c10::FunctionSchema& schema, std::vector<c10::IValue> inputs, std::vector<c10::IValue*> outputs, bool average)
26  : Operator<Context>(schema, std::move(inputs), std::move(outputs)), average_(average) {}
27  ~SumElementsOp() {}
28 
29  bool RunOnDevice() override {
30  auto& X = Input(0);
31 
32  auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
33 
34  T* data = sum->template mutable_data<T>();
35 
36  math::Sum<T, Context>(
37  X.numel(), X.template data<T>(), data, &context_, &scratch_);
38  if (average_ && X.numel() > 0) {
39  math::Scale<float, T, Context>(
40  1,
41  static_cast<T>(1.) / X.numel(),
42  sum->template data<T>(),
43  data,
44  &context_);
45  }
46  return true;
47  }
48 
49  private:
50  bool average_;
51  Tensor scratch_{Context::GetDeviceType()};
52 };
53 
54 template <typename T, class Context>
55 class SumElementsIntOp : public Operator<Context> {
56  public:
57  USE_OPERATOR_CONTEXT_FUNCTIONS;
58 
59  template <class... Args>
60  explicit SumElementsIntOp(Args&&... args)
61  : Operator<Context>(std::forward<Args>(args)...) {}
62  ~SumElementsIntOp() {}
63 
64  bool RunOnDevice() override {
65  auto& X = Input(0);
66 
67  auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
68  T* data = sum->template mutable_data<T>();
69  math::Sum<T, Context>(
70  X.numel(), X.template data<T>(), data, &context_, &scratch_);
71  return true;
72  }
73 
74  private:
75  Tensor scratch_{Context::GetDeviceType()};
76 };
77 
78 template <typename T, class Context>
79 class SumElementsGradientOp : public Operator<Context> {
80  public:
81  USE_OPERATOR_CONTEXT_FUNCTIONS;
82 
83  explicit SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws)
84  : Operator<Context>(operator_def, ws),
85  average_(this->template GetSingleArgument<bool>("average", false)) {}
86  explicit SumElementsGradientOp(const OperatorDef& operator_def, Workspace* ws, bool average)
87  : Operator<Context>(operator_def, ws), average_(average) {}
88  explicit SumElementsGradientOp(const c10::FunctionSchema& schema, std::vector<c10::IValue> inputs, std::vector<c10::IValue*> outputs)
89  : Operator<Context>(schema, std::move(inputs), std::move(outputs)),
90  average_(this->template GetSingleArgument<bool>("average", false)) {}
91  explicit SumElementsGradientOp(const c10::FunctionSchema& schema, std::vector<c10::IValue> inputs, std::vector<c10::IValue*> outputs, bool average)
92  : Operator<Context>(schema, std::move(inputs), std::move(outputs)), average_(average) {}
94 
95  bool RunOnDevice() override;
96 
97  private:
98  bool average_;
99 };
100 
101 template <class Context>
102 class SumSqrElementsOp : public Operator<Context> {
103  public:
104  USE_SIMPLE_CTOR_DTOR(SumSqrElementsOp)
105  USE_OPERATOR_CONTEXT_FUNCTIONS;
106 
107  bool RunOnDevice() override {
108  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
109  }
110 
111  template <typename T>
112  bool DoRunWithType() {
113  bool average = this->template GetSingleArgument<bool>("average", false);
114  auto& X = Input(0);
115 
116  auto* sum = Output(0, vector<int64_t>(), at::dtype<T>());
117  math::SumSqr<T, Context>(
118  X.numel(),
119  X.template data<T>(),
120  sum->template mutable_data<T>(),
121  &context_,
122  &scratch_);
123  if (average && X.numel() > 0) {
124  math::Scale<float, T, Context>(
125  1,
126  float(1.) / X.numel(),
127  sum->template data<T>(),
128  sum->template mutable_data<T>(),
129  &context_);
130  }
131  return true;
132  }
133 
134  private:
135  Tensor scratch_{Context::GetDeviceType()};
136 };
137 
138 template <typename T, class Context, bool ROWWISE>
139 class MaxReductionOp : public Operator<Context> {
140  public:
141  USE_SIMPLE_CTOR_DTOR(MaxReductionOp)
142  USE_OPERATOR_CONTEXT_FUNCTIONS;
143 
144  bool RunOnDevice() override {
145  auto& X = Input(0);
146  CAFFE_ENFORCE_EQ(X.dim(), 3);
147 
148  const int batch_size = X.dim32(0);
149  const int M = X.dim32(1);
150  const int N = X.dim32(2);
151 
152  auto* Y = Output(0, {batch_size, ROWWISE ? M : N}, at::dtype<T>());
153 
154  if (ROWWISE) {
155  math::RowwiseMax<T, Context>(
156  batch_size * M,
157  N,
158  X.template data<T>(),
159  Y->template mutable_data<T>(),
160  &context_);
161  } else {
162  const int input_size = N * M;
163  for (int i = 0; i < batch_size; ++i) {
164  math::ColwiseMax<T, Context>(
165  M,
166  N,
167  X.template data<T>() + i * input_size,
168  Y->template mutable_data<T>() + i * N,
169  &context_);
170  }
171  }
172  return true;
173  }
174 };
175 
176 template <typename T, class Context, bool ROWWISE>
177 class MaxReductionGradientOp : public Operator<Context> {
178  public:
179  USE_SIMPLE_CTOR_DTOR(MaxReductionGradientOp)
180  USE_OPERATOR_CONTEXT_FUNCTIONS;
181 
182  bool RunOnDevice() override;
183 };
184 
185 } // namespace caffe2
186 
187 #endif
Definition: any.cpp:108
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13