1 #ifndef CAFFE2_OPERATORS_REDUCE_OPS_H_ 2 #define CAFFE2_OPERATORS_REDUCE_OPS_H_ 8 #include "caffe2/core/context.h" 9 #include "caffe2/core/operator.h" 10 #include "caffe2/core/types.h" 11 #include "caffe2/utils/math.h" 15 template <
typename InputTypes,
class Context,
class Reducer>
18 USE_OPERATOR_CONTEXT_FUNCTIONS;
20 template <
class... Args>
23 axes_(this->
template GetRepeatedArgument<int>(
"axes")),
24 OP_SINGLE_ARG(
bool,
"keepdims", keep_dims_,
true) {}
26 bool RunOnDevice()
override {
31 bool DoRunWithType() {
32 const auto& X =
Input(0);
33 const int ndim = X.dim();
34 const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
37 std::iota(axes_.begin(), axes_.end(), 0);
39 for (
auto& axis : axes_) {
40 axis = X.canonical_axis_index(axis);
42 std::sort(axes_.begin(), axes_.end());
43 CAFFE_ENFORCE_GE(axes_.front(), 0,
"Axes ids must be non-negative.");
47 "Axes ids must be smaller than the dimensions of input.");
49 std::vector<int64_t> output_dims;
50 output_dims.reserve(ndim);
51 std::size_t cur_axis = 0;
52 for (
int i = 0; i < ndim; ++i) {
53 if (cur_axis < axes_.size() && i == axes_[cur_axis]) {
55 output_dims.push_back(1);
59 output_dims.push_back(X_dims[i]);
62 auto* Y = Output(0, output_dims, at::dtype<T>());
64 std::vector<int> Y_dims = X_dims;
65 for (
const int axis : axes_) {
69 return reducer_.template Forward<T>(
73 Y->template mutable_data<T>(),
78 std::vector<int> axes_;
80 const Reducer reducer_{};
83 template <
typename InputTypes,
class Context,
class Reducer>
86 USE_OPERATOR_CONTEXT_FUNCTIONS;
88 template <
class... Args>
91 axes_(this->
template GetRepeatedArgument<int>(
"axes")) {}
93 bool RunOnDevice()
override {
98 bool DoRunWithType() {
99 const auto& dY =
Input(0);
100 const auto& X =
Input(1);
101 const auto& Y =
Input(2);
103 const int ndim = X.dim();
106 std::iota(axes_.begin(), axes_.end(), 0);
108 for (
auto& axis : axes_) {
109 axis = X.canonical_axis_index(axis);
111 std::sort(axes_.begin(), axes_.end());
112 CAFFE_ENFORCE_GE(axes_.front(), 0,
"Axes ids must be non-negative.");
116 "Axes ids must be smaller than the dimensions of input.");
118 const std::vector<int> dX_dims(X.sizes().cbegin(), X.sizes().cend());
119 std::vector<int> dY_dims = dX_dims;
120 for (
const int axis : axes_) {
123 auto* dX = Output(0, X.sizes(), at::dtype<T>());
124 return reducer_.template Backward<T>(
127 dY.template data<T>(),
128 X.template data<T>(),
129 Y.template data<T>(),
130 dX->template mutable_data<T>(),
135 std::vector<int> axes_;
136 const Reducer reducer_{};
139 template <
class Context>
141 template <
typename T>
143 const std::vector<int>& X_dims,
144 const std::vector<int>& Y_dims,
147 Context* context)
const {
148 math::ReduceMin<T, Context>(
159 template <
typename T>
161 const std::vector<int>& dY_dims,
162 const std::vector<int>& dX_dims,
167 Context* context)
const;
170 template <
class Context>
172 template <
typename T>
174 const std::vector<int>& X_dims,
175 const std::vector<int>& Y_dims,
178 Context* context)
const {
179 math::ReduceMax<T, Context>(
190 template <
typename T>
192 const std::vector<int>& dY_dims,
193 const std::vector<int>& dX_dims,
198 Context* context)
const;
201 template <
class Context>
203 template <
typename T>
205 const std::vector<int>& X_dims,
206 const std::vector<int>& Y_dims,
209 Context* context)
const {
210 math::ReduceSum<T, Context>(
221 template <
typename T>
223 const std::vector<int>& dY_dims,
224 const std::vector<int>& dX_dims,
229 Context* context)
const {
243 template <
class Context>
245 template <
typename T>
247 const std::vector<int>& X_dims,
248 const std::vector<int>& Y_dims,
251 Context* context)
const {
252 math::ReduceMean<T, Context>(
263 template <
typename T>
265 const std::vector<int>& dY_dims,
266 const std::vector<int>& dX_dims,
271 Context* context)
const {
272 const int dY_size = std::accumulate(
273 dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
274 const int dX_size = std::accumulate(
275 dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
281 static_cast<T>(dY_size) / static_cast<T>(dX_size),
289 template <
class Context>
291 template <
typename T>
293 const std::vector<int>& X_dims,
294 const std::vector<int>& Y_dims,
297 Context* context)
const {
298 math::ReduceL1<T, Context>(
309 template <
typename T>
311 const std::vector<int>& dY_dims,
312 const std::vector<int>& dX_dims,
317 Context* context)
const;
320 template <
class Context>
322 template <
typename T>
324 const std::vector<int>& X_dims,
325 const std::vector<int>& Y_dims,
328 Context* context)
const {
329 math::ReduceL2<T, Context>(
340 template <
typename T>
342 const std::vector<int>& dY_dims,
343 const std::vector<int>& dX_dims,
348 Context* context)
const;
353 #endif // CAFFE2_OPERATORS_REDUCE_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...