Caffe2 - C++ API
A deep learning, cross platform ML framework
reduce_ops.h
1 #ifndef CAFFE2_OPERATORS_REDUCE_OPS_H_
2 #define CAFFE2_OPERATORS_REDUCE_OPS_H_
3 
4 #include <algorithm>
5 #include <functional>
6 #include <vector>
7 
8 #include "caffe2/core/context.h"
9 #include "caffe2/core/operator.h"
10 #include "caffe2/core/types.h"
11 #include "caffe2/utils/math.h"
12 
13 namespace caffe2 {
14 
15 template <typename InputTypes, class Context, class Reducer>
16 class ReduceOp final : public Operator<Context> {
17  public:
18  USE_OPERATOR_CONTEXT_FUNCTIONS;
19 
20  template <class... Args>
21  explicit ReduceOp(Args&&... args)
22  : Operator<Context>(std::forward<Args>(args)...),
23  axes_(this->template GetRepeatedArgument<int>("axes")),
24  OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
25 
26  bool RunOnDevice() override {
27  return DispatchHelper<InputTypes>::call(this, Input(0));
28  }
29 
30  template <typename T>
31  bool DoRunWithType() {
32  const auto& X = Input(0);
33  const int ndim = X.dim();
34  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
35  if (axes_.empty()) {
36  axes_.resize(ndim);
37  std::iota(axes_.begin(), axes_.end(), 0);
38  } else {
39  for (auto& axis : axes_) {
40  axis = X.canonical_axis_index(axis);
41  }
42  std::sort(axes_.begin(), axes_.end());
43  CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative.");
44  CAFFE_ENFORCE_LT(
45  axes_.back(),
46  ndim,
47  "Axes ids must be smaller than the dimensions of input.");
48  }
49  std::vector<int64_t> output_dims;
50  output_dims.reserve(ndim);
51  std::size_t cur_axis = 0;
52  for (int i = 0; i < ndim; ++i) {
53  if (cur_axis < axes_.size() && i == axes_[cur_axis]) {
54  if (keep_dims_) {
55  output_dims.push_back(1);
56  }
57  ++cur_axis;
58  } else {
59  output_dims.push_back(X_dims[i]);
60  }
61  }
62  auto* Y = Output(0, output_dims, at::dtype<T>());
63 
64  std::vector<int> Y_dims = X_dims;
65  for (const int axis : axes_) {
66  Y_dims[axis] = 1;
67  }
68 
69  return reducer_.template Forward<T>(
70  X_dims,
71  Y_dims,
72  X.template data<T>(),
73  Y->template mutable_data<T>(),
74  &context_);
75  }
76 
77  private:
78  std::vector<int> axes_;
79  const int keep_dims_;
80  const Reducer reducer_{};
81 };
82 
83 template <typename InputTypes, class Context, class Reducer>
84 class ReduceGradientOp final : public Operator<Context> {
85  public:
86  USE_OPERATOR_CONTEXT_FUNCTIONS;
87 
88  template <class... Args>
89  explicit ReduceGradientOp(Args&&... args)
90  : Operator<Context>(std::forward<Args>(args)...),
91  axes_(this->template GetRepeatedArgument<int>("axes")) {}
92 
93  bool RunOnDevice() override {
94  return DispatchHelper<InputTypes>::call(this, Input(0));
95  }
96 
97  template <typename T>
98  bool DoRunWithType() {
99  const auto& dY = Input(0);
100  const auto& X = Input(1);
101  const auto& Y = Input(2);
102 
103  const int ndim = X.dim();
104  if (axes_.empty()) {
105  axes_.resize(ndim);
106  std::iota(axes_.begin(), axes_.end(), 0);
107  } else {
108  for (auto& axis : axes_) {
109  axis = X.canonical_axis_index(axis);
110  }
111  std::sort(axes_.begin(), axes_.end());
112  CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative.");
113  CAFFE_ENFORCE_LT(
114  axes_.back(),
115  ndim,
116  "Axes ids must be smaller than the dimensions of input.");
117  }
118  const std::vector<int> dX_dims(X.sizes().cbegin(), X.sizes().cend());
119  std::vector<int> dY_dims = dX_dims;
120  for (const int axis : axes_) {
121  dY_dims[axis] = 1;
122  }
123  auto* dX = Output(0, X.sizes(), at::dtype<T>());
124  return reducer_.template Backward<T>(
125  dY_dims,
126  dX_dims,
127  dY.template data<T>(),
128  X.template data<T>(),
129  Y.template data<T>(),
130  dX->template mutable_data<T>(),
131  &context_);
132  }
133 
134  private:
135  std::vector<int> axes_;
136  const Reducer reducer_{};
137 };
138 
139 template <class Context>
140 struct MinReducer {
141  template <typename T>
142  bool Forward(
143  const std::vector<int>& X_dims,
144  const std::vector<int>& Y_dims,
145  const T* X_data,
146  T* Y_data,
147  Context* context) const {
148  math::ReduceMin<T, Context>(
149  X_dims.size(),
150  X_dims.data(),
151  Y_dims.data(),
152  T(1),
153  X_data,
154  Y_data,
155  context);
156  return true;
157  }
158 
159  template <typename T>
160  bool Backward(
161  const std::vector<int>& dY_dims,
162  const std::vector<int>& dX_dims,
163  const T* dY_data,
164  const T* X_data,
165  const T* Y_data,
166  T* dX_data,
167  Context* context) const;
168 };
169 
170 template <class Context>
171 struct MaxReducer {
172  template <typename T>
173  bool Forward(
174  const std::vector<int>& X_dims,
175  const std::vector<int>& Y_dims,
176  const T* X_data,
177  T* Y_data,
178  Context* context) const {
179  math::ReduceMax<T, Context>(
180  X_dims.size(),
181  X_dims.data(),
182  Y_dims.data(),
183  T(1),
184  X_data,
185  Y_data,
186  context);
187  return true;
188  }
189 
190  template <typename T>
191  bool Backward(
192  const std::vector<int>& dY_dims,
193  const std::vector<int>& dX_dims,
194  const T* dY_data,
195  const T* X_data,
196  const T* Y_data,
197  T* dX_data,
198  Context* context) const;
199 };
200 
201 template <class Context>
202 struct SumReducer {
203  template <typename T>
204  bool Forward(
205  const std::vector<int>& X_dims,
206  const std::vector<int>& Y_dims,
207  const T* X_data,
208  T* Y_data,
209  Context* context) const {
210  math::ReduceSum<T, Context>(
211  X_dims.size(),
212  X_dims.data(),
213  Y_dims.data(),
214  T(1),
215  X_data,
216  Y_data,
217  context);
218  return true;
219  }
220 
221  template <typename T>
222  bool Backward(
223  const std::vector<int>& dY_dims,
224  const std::vector<int>& dX_dims,
225  const T* dY_data,
226  const T* /* X_data */,
227  const T* /* Y_data */,
228  T* dX_data,
229  Context* context) const {
230  math::Broadcast(
231  dY_dims.size(),
232  dY_dims.data(),
233  dX_dims.size(),
234  dX_dims.data(),
235  T(1),
236  dY_data,
237  dX_data,
238  context);
239  return true;
240  }
241 };
242 
243 template <class Context>
244 struct MeanReducer {
245  template <typename T>
246  bool Forward(
247  const std::vector<int>& X_dims,
248  const std::vector<int>& Y_dims,
249  const T* X_data,
250  T* Y_data,
251  Context* context) const {
252  math::ReduceMean<T, Context>(
253  X_dims.size(),
254  X_dims.data(),
255  Y_dims.data(),
256  T(1),
257  X_data,
258  Y_data,
259  context);
260  return true;
261  }
262 
263  template <typename T>
264  bool Backward(
265  const std::vector<int>& dY_dims,
266  const std::vector<int>& dX_dims,
267  const T* dY_data,
268  const T* /* X_data */,
269  const T* /* Y_data */,
270  T* dX_data,
271  Context* context) const {
272  const int dY_size = std::accumulate(
273  dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
274  const int dX_size = std::accumulate(
275  dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
276  math::Broadcast(
277  dY_dims.size(),
278  dY_dims.data(),
279  dX_dims.size(),
280  dX_dims.data(),
281  static_cast<T>(dY_size) / static_cast<T>(dX_size),
282  dY_data,
283  dX_data,
284  context);
285  return true;
286  }
287 };
288 
289 template <class Context>
290 struct L1Reducer {
291  template <typename T>
292  bool Forward(
293  const std::vector<int>& X_dims,
294  const std::vector<int>& Y_dims,
295  const T* X_data,
296  T* Y_data,
297  Context* context) const {
298  math::ReduceL1<T, Context>(
299  X_dims.size(),
300  X_dims.data(),
301  Y_dims.data(),
302  T(1),
303  X_data,
304  Y_data,
305  context);
306  return true;
307  }
308 
309  template <typename T>
310  bool Backward(
311  const std::vector<int>& dY_dims,
312  const std::vector<int>& dX_dims,
313  const T* dY_data,
314  const T* X_data,
315  const T* Y_data,
316  T* dX_data,
317  Context* context) const;
318 };
319 
320 template <class Context>
321 struct L2Reducer {
322  template <typename T>
323  bool Forward(
324  const std::vector<int>& X_dims,
325  const std::vector<int>& Y_dims,
326  const T* X_data,
327  T* Y_data,
328  Context* context) const {
329  math::ReduceL2<T, Context>(
330  X_dims.size(),
331  X_dims.data(),
332  Y_dims.data(),
333  T(1),
334  X_data,
335  Y_data,
336  context);
337  return true;
338  }
339 
340  template <typename T>
341  bool Backward(
342  const std::vector<int>& dY_dims,
343  const std::vector<int>& dX_dims,
344  const T* dY_data,
345  const T* X_data,
346  const T* Y_data,
347  T* dX_data,
348  Context* context) const;
349 };
350 
351 } // namespace caffe2
352 
353 #endif // CAFFE2_OPERATORS_REDUCE_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position &#39;idx&#39; for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13