// Accumulates the per-channel gradients of the affine scale (dscale) and
// bias (dbias) from the upstream gradient dY and the forward input X.
// NOTE(review): the parameter list and several body lines are not visible in
// this view; the comments below describe only the statements shown.
1 #include "caffe2/operators/affine_channel_op.h" 2 #include "caffe2/utils/eigen_utils.h" 11 void AffineChannelScaleBiasBackwardNCHW(
// Element count of one C x HxW slab; presumably the per-iteration step for
// dY_ptr / X_ptr (the pointer update itself is not visible here -- confirm).
21 const int stride = C * HxW;
// Length-C array views over the caller-provided output buffers.
22 EigenVectorArrayMap<T> dscale_arr(dscale, C);
23 EigenVectorArrayMap<T> dbias_arr(dbias, C);
// NOTE(review): both outputs are updated with `+=` below, so they must start
// from zero; the initialization is not visible in this view -- confirm.
26 for (
int i = 0; i < N; ++i) {
// View the current slab as an HxW-by-C array (one column per channel).
27 ConstEigenArrayMap<T> dY_arr(dY_ptr, HxW, C);
28 ConstEigenArrayMap<T> X_arr(X_ptr, HxW, C);
// Column-wise reductions produce one value per channel:
//   dscale[c] += sum(dY * X) over the column, dbias[c] += sum(dY).
29 dscale_arr += (dY_arr * X_arr).colwise().sum();
30 dbias_arr += dY_arr.colwise().sum();
// Computes the per-channel gradients of the affine scale (dscale) and bias
// (dbias) in a single pass over the whole buffer.
// NOTE(review): the parameter list is not visible in this view.
37 void AffineChannelScaleBiasBackwardNHWC(
// View both buffers as C-by-(N * HxW) arrays (one row per channel).
45 ConstEigenArrayMap<T> dY_arr(dY, C, N * HxW);
46 ConstEigenArrayMap<T> X_arr(X, C, N * HxW);
// Row-wise reductions give one value per channel. The results are assigned
// (not accumulated), so the outputs need no prior initialization.
47 EigenVectorMap<T>(dscale, C) = (dY_arr * X_arr).rowwise().sum();
48 EigenVectorMap<T>(dbias, C) = dY_arr.rowwise().sum();
// Gradient of AffineChannel for float tensors on CPU, NCHW variant.
// Writes dX to Output(0) and, further down, dscale / dbias to Output(1) and
// Output(2).
// NOTE(review): several lines of this function (the math::Mul argument list
// and the control flow around the dscale/dbias section) are not visible in
// this view.
54 bool AffineChannelGradientOp<float, CPUContext>::RunOnDeviceWithOrderNCHW() {
55 const auto& dY = Input(0);
// scale is the third input when is_learnable_ is set, otherwise the second.
56 const auto& scale = is_learnable_ ? Input(2) : Input(1);
// dX has the same sizes as dY.
58 auto* dX = Output(0, dY.sizes(), at::dtype<float>());
59 const int N = dY.dim32(0);
60 const int C = dY.dim32(1);
// Remaining element count per (n, c) pair.
// NOTE(review): divides by N * C; a tensor with N == 0 or C == 0 would
// divide by zero here -- confirm this is guarded upstream.
61 const int HxW = dY.numel() / (N * C);
62 const float* dY_data = dY.data<
float>();
63 const float* scale_data = scale.data<
float>();
// Shape descriptors {N, C, HxW} and {1, C, 1}; presumably passed to the
// math::Mul call below so scale is broadcast along the channel axis (the
// argument list is not visible -- confirm).
64 const std::array<int, 3> X_dims = {N, C, HxW};
65 const std::array<int, 3> scale_dims = {1, C, 1};
66 math::Mul<float, CPUContext>(
73 dX->template mutable_data<float>(),
// From here on Input(1) is read as X. Since scale is Input(1) when
// is_learnable_ is false, this section presumably runs only when
// is_learnable_ is true (the guarding condition is not visible -- confirm).
76 const auto& X = Input(1);
77 const float* X_data = X.data<
float>();
// dscale and dbias both take the sizes of scale.
79 auto* dscale = Output(1, scale.sizes(), at::dtype<float>());
80 auto* dbias = Output(2, scale.sizes(), at::dtype<float>());
81 AffineChannelScaleBiasBackwardNCHW<float>(
87 dscale->template mutable_data<float>(),
88 dbias->template mutable_data<float>());
// Gradient of AffineChannel for float tensors on CPU, NHWC variant.
// Writes dX to Output(0) and, further down, dscale / dbias to Output(1) and
// Output(2).
// NOTE(review): several lines of this function (the math::RowwiseMul argument
// list and the control flow around the dscale/dbias section) are not visible
// in this view.
94 bool AffineChannelGradientOp<float, CPUContext>::RunOnDeviceWithOrderNHWC() {
95 const auto& dY = Input(0);
// scale is the third input when is_learnable_ is set, otherwise the second.
96 const auto& scale = is_learnable_ ? Input(2) : Input(1);
// dX has the same sizes as dY.
98 auto* dX = Output(0, dY.sizes(), at::dtype<float>());
99 const int ndim = dY.dim();
// The channel count is taken from the last dimension of dY.
100 const int C = dY.dim32(ndim - 1);
// Number of length-C rows when dY is viewed as a 2-D (rows x C) buffer.
// NOTE(review): divides by C; C == 0 would divide by zero -- confirm this is
// guarded upstream.
101 const int rows = dY.numel() / C;
103 const float* dY_data = dY.data<
float>();
104 const float* scale_data = scale.data<
float>();
// dX's buffer is handed to math::RowwiseMul; the remaining arguments are not
// visible here (presumably rows, C, dY_data and scale_data -- confirm).
105 math::RowwiseMul<float, CPUContext>(
110 dX->template mutable_data<float>(),
// From here on Input(1) is read as X. Since scale is Input(1) when
// is_learnable_ is false, this section presumably runs only when
// is_learnable_ is true (the guarding condition is not visible -- confirm).
113 const auto& X = Input(1);
114 const float* X_data = X.data<
float>();
115 const int N = X.dim32(0);
// Per-sample spatial element count.
// NOTE(review): divides by N; an empty batch (N == 0) would divide by zero
// -- confirm this is guarded upstream.
116 const int HxW = rows / N;
// dscale and dbias both take the sizes of scale.
118 auto* dscale = Output(1, scale.sizes(), at::dtype<float>());
119 auto* dbias = Output(2, scale.sizes(), at::dtype<float>());
120 AffineChannelScaleBiasBackwardNHWC<float>(
126 dscale->template mutable_data<float>(),
127 dbias->template mutable_data<float>());
// Register the float CPU implementations under the operator names
// "AffineChannel" (forward) and "AffineChannelGradient" (backward).
132 REGISTER_CPU_OPERATOR(AffineChannel, AffineChannelOp<float, CPUContext>);
133 REGISTER_CPU_OPERATOR(
134 AffineChannelGradient,
135 AffineChannelGradientOp<float, CPUContext>);
// Schema for the AffineChannel operator: declares an in-place pairing between
// input 0 and output 0 and carries the user-facing descriptions of the feature
// map input X, the per-channel scale and bias inputs, and the output Y.
// NOTE(review): some lines of the schema chain are not visible in this view;
// the text below includes runtime documentation strings and is left untouched.
137 OPERATOR_SCHEMA(AffineChannel)
140 .AllowInplace({{0, 0}})
142 Applies a separate affine transformation to each channel of the input. Useful 143 for replacing spatial batch norm with its equivalent fixed transformation. 145 .Input(0, "X",
"Feature map input with order NCHW or NHWC.")
149 "1D input of shape (C); the c-th element is the scale factor of the " 150 "affine transformation for the c-th channel of the input.")
154 "1D input of shape (C); the c-th element is the bias of the affine " 155 "transformation for the c-th channel of the input.")
156 .Output(0,
"Y",
"Output with the same order of Input.");
// Schema for the gradient operator: declares an in-place pairing between
// input 0 (dY in the CPU implementation) and output 0 (dX).
// NOTE(review): intermediate lines of the schema chain are not visible here.
158 OPERATOR_SCHEMA(AffineChannelGradient)
161 .AllowInplace({{0, 0}});
// Gradient maker for AffineChannel: builds the AffineChannelGradient operator
// definition, choosing its inputs and outputs from the forward op's
// "is_learnable" argument.
// NOTE(review): the branch on is_learnable and the closing braces are not
// visible in this view.
165 class GetAffineChannelGradient :
public GradientMakerBase {
166 using GradientMakerBase::GradientMakerBase;
167 std::vector<OperatorDef> GetGradientDefs()
override {
168 ArgumentHelper arg_helper(def_);
// Read the forward op's "is_learnable" argument, defaulting to false.
169 const bool is_learnable =
170 arg_helper.GetSingleArgument(
"is_learnable",
false);
// Three-input form {dY, X, scale} producing {dX, dscale, dbias}. This matches
// the gradient op reading scale from Input(2) when is_learnable_ is set, so
// it is presumably the is_learnable branch (condition not visible -- confirm).
172 return SingleGradientDef(
173 "AffineChannelGradient",
175 std::vector<std::string>{GO(0), I(0), I(1)},
176 std::vector<std::string>{GI(0), GI(1), GI(2)});
// Two-input form {dY, scale} producing only dX; presumably the
// non-learnable branch (condition not visible -- confirm).
178 return SingleGradientDef(
179 "AffineChannelGradient",
181 std::vector<std::string>{GO(0), I(1)},
182 std::vector<std::string>{GI(0)});
// Associate the AffineChannel operator with its gradient maker.
189 REGISTER_GRADIENT(AffineChannel, GetAffineChannelGradient);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...