// NOTE(review): this text is a numbered listing of the header, not the raw
// file. Each fragment is prefixed with its original line number, and the
// numbers jump (e.g. 13 -> 16, 18 -> 21), so the class declaration, the
// constructor signature, several conditions and all closing braces are not
// visible here. The comments below describe only what is visible.
//
// Forward operator (presumably MomentsOp, going by the include guard; the
// class name itself is not shown). The visible part of RunOnDevice reads
// Input(0), normalizes axes_, creates outputs 0 and 1 with dtype T and an
// identical shape, and passes their buffers to math::Moments.
1 #ifndef CAFFE2_OPERATORS_MOMENTS_OP_H_ 2 #define CAFFE2_OPERATORS_MOMENTS_OP_H_ 7 #include "caffe2/core/context.h" 8 #include "caffe2/core/operator.h" 9 #include "caffe2/utils/math.h" 13 template <
typename T,
class Context>
16 USE_OPERATOR_CONTEXT_FUNCTIONS;
18 template <
class... Args>
// Constructor initializer list (signature not visible): axes_ is filled from
// the repeated int argument "axes"; OP_SINGLE_ARG binds the bool argument
// "keepdims" to keep_dims_ (presumably with `true` as the default; the macro
// is defined elsewhere).
21 axes_(this->
template GetRepeatedArgument<int>(
"axes")),
22 OP_SINGLE_ARG(
bool,
"keepdims", keep_dims_,
true) {}
// Operator entry point; X is the only input read in the visible code.
24 bool RunOnDevice()
override {
25 const auto& X =
Input(0);
27 const int ndim = X.dim();
// std::iota overwrites axes_ with 0, 1, 2, ... . The statements between
// original lines 27 and 30 that guard and size this are not visible;
// presumably this is the "no axes given, use every axis" path. TODO confirm.
30 std::iota(axes_.begin(), axes_.end(), 0);
// After sorting, axes_.front() is the smallest axis id, which the check below
// relies on. Note that this mutates the member, not a local copy.
32 std::sort(axes_.begin(), axes_.end());
33 CAFFE_ENFORCE_GE(axes_.front(), 0,
"Axes ids must be non-negative.");
// Message of a second enforce whose macro name and operands (original lines
// 34-36) are not visible in this listing.
37 "Axes ids must be smaller than the dimensions of input.");
// Input shape copied into an int vector; Y_dims starts as a copy of it.
39 const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
40 std::vector<int> Y_dims = X_dims;
// Loop body (original lines 42-43) is not visible; presumably it rewrites
// Y_dims[axis] for each listed axis. TODO confirm.
41 for (
const int axis : axes_) {
// Build the shape handed to Output(): walk the dimensions in order and compare
// each index with the current entry of the sorted axes_.
44 std::vector<std::int64_t> output_dims;
45 output_dims.reserve(ndim);
46 std::size_t cur_axis = 0;
47 for (
int i = 0; i < ndim; ++i) {
48 if (cur_axis < axes_.size() && i == axes_[cur_axis]) {
// A listed axis contributes extent 1. Original line 49 is not visible;
// presumably this push is conditional on keep_dims_. TODO confirm.
50 output_dims.push_back(1);
// Presumably the else branch (original lines 51-53 not visible): an unlisted
// axis keeps its input extent.
54 output_dims.push_back(X_dims[i]);
// Both outputs are created with the same shape and dtype T.
57 auto* mean = Output(0, output_dims, at::dtype<T>());
58 auto* var = Output(1, output_dims, at::dtype<T>());
// NOTE(review): the first template argument is hard-coded to float while the
// buffers below are obtained as T*. For T != float the pointer types would not
// match; confirm whether this should be math::Moments<T, Context>.
// The remaining call arguments (original lines 60-63, 66) are not visible.
59 math::Moments<float, Context>(
64 mean->template mutable_data<T>(),
65 var->template mutable_data<T>(),
// Axis ids taken from the "axes" argument; sorted (or overwritten by iota)
// in place by RunOnDevice.
71 std::vector<int> axes_;
// NOTE(review): numbered listing with gaps, not the raw file. The class
// declaration, constructor signature, several conditions, the callee of the
// call near original line 112 and all closing braces are not visible here.
//
// Gradient operator (presumably MomentsGradientOp; the class name is not
// shown). The visible part of RunOnDevice reads four inputs, normalizes axes_,
// creates output 0 with the sizes of X and dtype T, and passes the input and
// output buffers to a call whose name is not visible.
75 template <
typename T,
class Context>
78 USE_OPERATOR_CONTEXT_FUNCTIONS;
80 template <
class... Args>
// Constructor initializer list (signature not visible): axes_ is filled from
// the repeated int argument "axes".
83 axes_(this->
template GetRepeatedArgument<int>(
"axes")) {}
// Operator entry point. Input order as read here: 0 = dmean, 1 = dvariance,
// 2 = X, 3 = mean.
85 bool RunOnDevice()
override {
86 const auto& dmean =
Input(0);
87 const auto& dvariance =
Input(1);
88 const auto& X =
Input(2);
89 const auto& mean =
Input(3);
91 const int ndim = X.dim();
// std::iota overwrites axes_ with 0, 1, 2, ... . The statements between
// original lines 91 and 94 that guard and size this are not visible;
// presumably this is the "no axes given, use every axis" path. TODO confirm.
94 std::iota(axes_.begin(), axes_.end(), 0);
// After sorting, axes_.front() is the smallest axis id, which the check below
// relies on. Note that this mutates the member, not a local copy.
96 std::sort(axes_.begin(), axes_.end());
97 CAFFE_ENFORCE_GE(axes_.front(), 0,
"Axes ids must be non-negative.");
// Message of a second enforce whose macro name and operands (original lines
// 98-100) are not visible in this listing.
101 "Axes ids must be smaller than the dimensions of input.");
// Shape of X copied into an int vector; dY_dims starts as a copy of it.
103 const std::vector<int> dX_dims(X.sizes().cbegin(), X.sizes().cend());
104 std::vector<int> dY_dims = dX_dims;
// Loop body (original lines 106-107) is not visible; presumably it rewrites
// dY_dims[axis] for each listed axis. TODO confirm.
105 for (
const int axis : axes_) {
// The single output has the same sizes as X and dtype T.
108 auto* dX = Output(0, X.sizes(), at::dtype<T>());
// Trailing arguments of a call whose callee and leading arguments (original
// lines 109-111) are not visible: the four input buffers, then the output
// buffer.
112 dmean.template data<T>(),
113 dvariance.template data<T>(),
114 X.template data<T>(),
115 mean.template data<T>(),
116 dX->template mutable_data<T>());
// Part of a parameter list belonging to a declaration whose name and
// remaining parameters (original lines 120, 123, 125-127) are not visible;
// presumably the helper invoked above. TODO confirm.
121 const std::vector<int>& dY_dims,
122 const std::vector<int>& dX_dims,
124 const T* dvariance_data,
// Axis ids taken from the "axes" argument; sorted (or overwritten by iota)
// in place by RunOnDevice.
129 std::vector<int> axes_;
134 #endif // CAFFE2_OPERATORS_MOMENTS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...