1 #include "caffe2/operators/moments_op.h" 8 template <
typename T,
class Context>
9 bool MomentsGradientOp<T, Context>::Compute(
10 const std::vector<int>& dY_dims,
11 const std::vector<int>& dX_dims,
13 const T* dvariance_data,
17 const int dY_size = std::accumulate(
18 dY_dims.cbegin(), dY_dims.cend(), 1, std::multiplies<int>());
19 const int dX_size = std::accumulate(
20 dX_dims.cbegin(), dX_dims.cend(), 1, std::multiplies<int>());
21 const int ndim = dX_dims.size();
22 std::vector<int> index(ndim, 0);
23 const T norm =
static_cast<T>(dY_size) / static_cast<T>(dX_size);
24 for (
int dX_index = 0; dX_index < dX_size; ++dX_index) {
26 math::utils::GetIndexFromDims(ndim, dY_dims.data(), index.data());
28 (dmean_data[dY_index] +
29 static_cast<T>(2) * (X_data[dX_index] - mean_data[dY_index]) *
30 dvariance_data[dY_index]) *
32 math::utils::IncreaseIndexInDims(ndim, dX_dims.data(), index.data());
37 REGISTER_CPU_OPERATOR(Moments, MomentsOp<float, CPUContext>);
38 REGISTER_CPU_OPERATOR(MomentsGradient, MomentsGradientOp<float, CPUContext>);
40 OPERATOR_SCHEMA(Moments)
44 Computes the mean and variance of the input tensor's element along the 45 provided axes. The resulted tensor has the same rank as the input if keepdims 47 If keepdims equals False, then the resulted tensor have the reduced dimension 52 "A list of integers, along which to reduce. If axes is not provided, " 53 "the op computes the element-wise mean and variance.")
56 "Keep the reduced dimension(s) or not, default True keeps the reduced " 58 .Input(0,
"data",
"An input tensor.")
59 .Output(0,
"mean",
"Reduced mean tensor.")
60 .Output(1,
"variance",
"Reduced variance tensor.");
62 OPERATOR_SCHEMA(MomentsGradient).NumInputs(4).NumOutputs(1);
66 class GetMomentsGradient :
public GradientMakerBase {
67 using GradientMakerBase::GradientMakerBase;
69 std::vector<OperatorDef> GetGradientDefs()
override {
70 return SingleGradientDef(
73 std::vector<std::string>{GO(0), GO(1), I(0), O(0)},
74 std::vector<std::string>{GI(0)});
80 REGISTER_GRADIENT(Moments, GetMomentsGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...