// NOTE(review): this listing appears to be extracted from a rendered source
// view. Original line numbers are inlined into the text, and several original
// lines are missing (e.g. this function's parameter list, the loop's closing
// brace, the return), so it will not compile as-is.
//
// CPU/float forward pass of BatchMoments for NCHW input. As far as the visible
// lines show, it writes, for each channel c:
//   mu[c]  = (sum of X over all N images and HxW positions) / (N * HxW)
//   var[c] = (sum of X^2 over all N images and HxW positions) / (N * HxW)
// NOTE(review): no mu^2 term is subtracted in the lines shown. `var` therefore
// holds the mean of squares (second raw moment), not the centered variance;
// confirm against the omitted lines and the callers.
1 #include "caffe2/operators/batch_moments_op.h" 6 #include "caffe2/utils/eigen_utils.h" 7 #include "caffe2/utils/math.h" 12 bool BatchMomentsOp<float, CPUContext>::ComputeBatchMomentsNCHW(
// Zero both per-channel accumulators before summing over the batch.
19 math::Set<float, CPUContext>(C, 0.0f, mu, &context_);
20 math::Set<float, CPUContext>(C, 0.0f, var, &context_);
21 EigenVectorArrayMap<float> mu_arr(mu, C);
22 EigenVectorArrayMap<float> var_arr(var, C);
23 const float* X_ptr = X;
// Number of elements in one image: C planes of HxW values each.
24 const int stride = C * HxW;
// One pass per image. The advance of X_ptr by `stride` is not among the
// visible lines; presumably it is on an omitted line before the loop closes.
// Confirm.
25 for (
int i = 0; i < N; ++i) {
// View the current image as an HxW x C array. Assuming Eigen's default
// column-major layout, each column is one channel's HxW plane, so colwise()
// reductions yield per-channel sums.
26 ConstEigenArrayMap<float> X_arr(X_ptr, HxW, C);
27 mu_arr += X_arr.colwise().sum();
28 var_arr += X_arr.square().colwise().sum();
// Turn the accumulated sums into means over the N * HxW samples per channel.
31 const float scale = 1.0f /
static_cast<float>(N * HxW);
32 math::Scale<float, float, CPUContext>(C, scale, mu, mu, &context_);
33 math::Scale<float, float, CPUContext>(C, scale, var, var, &context_);
// CPU/float forward pass of BatchMoments for NHWC input. The whole batch is
// viewed as a single C x (N * HxW) array. Assuming Eigen's default column-major
// layout, each column holds the C channel values of one spatial position, so
// the mean of row c is channel c's mean over all N * HxW samples.
// NOTE(review): as in the NCHW path, `var` receives the mean of X^2, not a
// centered variance. The parameter list and return are not in this listing.
38 bool BatchMomentsOp<float, CPUContext>::ComputeBatchMomentsNHWC(
45 ConstEigenArrayMap<float> X_arr(X, C, N * HxW);
// Per-channel mean of X.
46 EigenVectorMap<float>(mu, C) = X_arr.rowwise().mean();
// Per-channel mean of X^2 (second raw moment).
47 EigenVectorMap<float>(var, C) = X_arr.square().rowwise().mean();
// CPU/float backward pass of BatchMoments for NCHW input. For every element x
// of channel c, the visible lines compute
//   dX = (2 * x * dvar[c] + dmu[c]) / (N * HxW),
// which matches differentiating the forward pass's per-channel mean (mu) and
// mean of squares (var).
// NOTE(review): the declaration of dX_ptr, the per-image advance of X_ptr and
// dX_ptr, and the return are on lines omitted from this listing.
52 bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNCHW(
60 ConstEigenVectorArrayMap<float> dmu_arr(dmu, C);
61 ConstEigenVectorArrayMap<float> dvar_arr(dvar, C);
62 const float* X_ptr = X;
// Number of elements in one image.
64 const int stride = C * HxW;
// Process one image at a time as an HxW x C array. Assuming column-major
// maps, there is one column per channel.
65 for (
int i = 0; i < N; ++i) {
66 EigenArrayMap<float> dX_arr(dX_ptr, HxW, C);
// Scale each channel column of X by 2 * dvar[c] ...
67 dX_arr = ConstEigenArrayMap<float>(X_ptr, HxW, C).rowwise() *
68 dvar_arr.transpose() * 2.0f;
// ... then add dmu[c] to every element of channel c.
69 dX_arr.rowwise() += dmu_arr.transpose();
// Apply the forward pass's 1 / (N * HxW) averaging factor to all
// N * C * HxW gradient elements in one call.
73 const float scale = 1.0f /
static_cast<float>(N * HxW);
74 math::Scale<float, float, CPUContext>(N * C * HxW, scale, dX, dX, &context_);
// CPU/float backward pass of BatchMoments for NHWC input. For every element x
// of channel c it computes dX = (2 * x * dvar[c] + dmu[c]) / (N * HxW), the
// same formula as the NCHW path.
// NOTE(review): the parameter list and return are not in this listing.
79 bool BatchMomentsGradientOp<float, CPUContext>::ComputeBatchMomentsGradientNHWC(
87 const float scale = 1.0f /
static_cast<float>(N * HxW);
// View the whole batch as C x (N * HxW). Assuming column-major maps, each
// column holds the C channel values of one spatial position.
88 EigenArrayMap<float> dX_arr(dX, C, N * HxW);
// Scale row c (channel c) of X by 2 * dvar[c] ...
89 dX_arr = ConstEigenArrayMap<float>(X, C, N * HxW).colwise() *
90 ConstEigenVectorArrayMap<float>(dvar, C) * 2.0f;
// ... add dmu[c] to every element of channel c ...
91 dX_arr.colwise() += ConstEigenVectorArrayMap<float>(dmu, C);
// ... and apply the 1 / (N * HxW) averaging factor to all elements.
92 math::Scale<float, float, CPUContext>(N * C * HxW, scale, dX, dX, &context_);
// Register the float CPU implementations of the forward and gradient ops.
96 REGISTER_CPU_OPERATOR(BatchMoments, BatchMomentsOp<float, CPUContext>);
// NOTE(review): the operator-name argument (original line 98) is missing from
// this listing; presumably BatchMomentsGradient. Confirm.
97 REGISTER_CPU_OPERATOR(
99 BatchMomentsGradientOp<float, CPUContext>);
// BatchMoments: 1 input, 2 outputs (presumably X -> mu, var).
// BatchMomentsGradient: 3 inputs, 1 output (presumably dmu, dvar, X -> dX,
// matching the input order built by GetBatchMomentsGradient).
101 OPERATOR_SCHEMA(BatchMoments).NumInputs(1).NumOutputs(2);
102 OPERATOR_SCHEMA(BatchMomentsGradient).NumInputs(3).NumOutputs(1);
// Gradient maker for BatchMoments. It emits a single BatchMomentsGradient op
// that consumes the gradients of both forward outputs (GO(0), GO(1)) plus the
// original forward input I(0), and produces the gradient of that input GI(0).
106 class GetBatchMomentsGradient :
public GradientMakerBase {
107 using GradientMakerBase::GradientMakerBase;
109 std::vector<OperatorDef> GetGradientDefs()
override {
110 return SingleGradientDef(
111 "BatchMomentsGradient",
// NOTE(review): an argument on original line 112 is missing from this listing.
113 std::vector<std::string>{GO(0), GO(1), I(0)},
114 std::vector<std::string>{GI(0)});
// Associate the gradient maker with the forward BatchMoments operator.
120 REGISTER_GRADIENT(BatchMoments, GetBatchMomentsGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...