1 #ifndef CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_ 2 #define CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_ 10 #include "caffe2/core/context.h" 11 #include "caffe2/core/operator.h" 12 #include "caffe2/utils/eigen_utils.h" 13 #include "caffe2/utils/math.h" 17 template <
class Context>
20 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor. Reads the operator arguments:
//   is_test (OpSchema::Arg_IsTest, default false), epsilon (default 1e-5),
//   momentum (default 0.9), order (default "NCHW"), num_batches (default 1).
// NOTE(review): this listing is lossy; the `class SpatialBNOp ...` line and
// the constructor's name/initializer head are not present here.
22 template <
class... Args>
25 OP_SINGLE_ARG(
bool, OpSchema::Arg_IsTest, is_test_,
false),
26 OP_SINGLE_ARG(
double,
"epsilon", epsilon_, 1e-5),
27 OP_SINGLE_ARG(
float,
"momentum", momentum_, 0.9f),
28 order_(StringToStorageOrder(
29 this->
template GetSingleArgument<std::string>(
"order",
"NCHW"))),
30 OP_SINGLE_ARG(
int,
"num_batches", num_batches_, 1) {
// Presumably rejects an order string that parsed to StorageOrder::UNKNOWN
// (the opening line of that enforce is not in this listing). The next check
// requires exactly 1 output when is_test_ is set, and exactly 5 otherwise.
33 StorageOrder::UNKNOWN,
34 "order should be either \"NCHW\" or \"NHWC\".");
36 (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 5));
// epsilon must be strictly positive; momentum must lie in [0, 1].
37 CAFFE_ENFORCE_GT(epsilon_, 0);
38 CAFFE_ENFORCE_GE(momentum_, 0);
39 CAFFE_ENFORCE_LE(momentum_, 1);
// Defaulted virtual destructor.
42 virtual ~SpatialBNOp() =
default;
// Operator entry point. Its body is not present in this listing; presumably
// it dispatches to DoRunWithType<T> by the input's element type (confirm).
44 bool RunOnDevice()
override {
// Forward pass for element type T (the `template <typename T>` line is not
// present in this listing). Computes per-channel coefficients alpha_/beta_
// and then applies math::AffineChannel to X (in the layout given by order_)
// to produce Y, presumably Y = alpha * X + beta per channel (confirm).
49 bool DoRunWithType() {
50 const auto& X =
Input(INPUT);
51 const auto& scale =
Input(SCALE);
52 const auto& bias =
Input(BIAS);
// X must have rank >= 3. N is dim 0; C is dim 1 (NCHW) or the last dim (NHWC).
54 const int ndim = X.dim();
55 CAFFE_ENFORCE_GE(ndim, 3);
56 const int N = X.dim32(0);
58 (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1));
// HxW is the int product of all non-batch dims divided by a term on a line
// missing from this listing (presumably C — confirm). scale and bias must
// each hold exactly C values.
59 const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
62 X_dims.cbegin() + 1, X_dims.cend(), 1, std::multiplies<int>()) /
64 CAFFE_ENFORCE_EQ(scale.numel(), C);
65 CAFFE_ENFORCE_EQ(bias.numel(), C);
// Y takes X's shape. alpha_ and beta_ are per-channel ({C}) scratch tensors
// allocated on the operator's device type.
67 auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
68 const T* X_data = X.template data<T>();
69 const T* scale_data = scale.template data<T>();
70 const T* bias_data = bias.template data<T>();
71 T* Y_data = Y->template mutable_data<T>();
73 &alpha_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
75 &beta_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
76 T* alpha_data = alpha_.template mutable_data<T>();
77 T* beta_data = beta_.template mutable_data<T>();
// Two paths follow; their branch lines are not in this listing (presumably
// keyed on is_test_).
//  - Inference: EST_MEAN / EST_VAR (each of size C) are folded into alpha/beta.
//  - Training: SAVED_MEAN / SAVED_INV_STD outputs of size C are allocated.
//    Inputs 3/4 must alias outputs 1/2 so the running statistics are updated
//    in place. Batch moments are then computed and the running stats updated.
82 const auto& mean =
Input(EST_MEAN);
83 const auto& var =
Input(EST_VAR);
84 CAFFE_ENFORCE_EQ(mean.numel(), C);
85 CAFFE_ENFORCE_EQ(var.numel(), C);
90 mean.template data<T>(),
91 var.template data<T>(),
95 auto* saved_mean = Output(SAVED_MEAN, {C}, at::dtype<T>());
96 auto* saved_rstd = Output(SAVED_INV_STD, {C}, at::dtype<T>());
97 T* saved_mean_data = saved_mean->template mutable_data<T>();
98 T* saved_rstd_data = saved_rstd->template mutable_data<T>();
102 IsInputOutputAlias(3, 1),
"Input 3 and Output 1 should be alias.");
104 IsInputOutputAlias(4, 2),
"Input 4 and Output 2 should be alias.");
106 Tensor* running_mean =
nullptr;
107 Tensor* running_var =
nullptr;
108 const auto& mean =
Input(EST_MEAN);
109 const auto& var =
Input(EST_VAR);
// A running statistic whose input does not already hold C values is
// (re)allocated with size C and zero-filled. A deprecation warning is also
// logged, at most once per 1000 ms.
110 if (mean.numel() != C) {
111 running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
112 C10_LOG_EVERY_MS(WARNING, 1000)
113 <<
"[Deprecated] Running mean is not initialized in " 114 "SpatialBatchNorm Op";
115 math::Set<T, Context>(
116 C,
T(0), running_mean->template mutable_data<T>(), &context_);
118 running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
120 if (var.numel() != C) {
121 running_var = Output(RUNNING_VAR, {C}, at::dtype<T>());
122 math::Set<T, Context>(
123 C,
T(0), running_var->template mutable_data<T>(), &context_);
124 C10_LOG_EVERY_MS(WARNING, 1000)
125 <<
"[Deprecated] Running variance is not initialized in " 126 "SpatialBatchNorm Op";
128 running_var = Output(RUNNING_VAR, {C}, at::dtype<T>());
131 T* running_mean_data = running_mean->template mutable_data<T>();
132 T* running_var_data = running_var->template mutable_data<T>();
// NOTE(review): the condition guarding these two zero-fills of the saved
// statistics is not in this listing (presumably an empty-batch check).
134 math::Set<T, Context>(C,
T(0), saved_mean_data, &context_);
135 math::Set<T, Context>(C,
T(0), saved_rstd_data, &context_);
// Batch moments. With num_batches_ > 1 they are derived from the
// pre-aggregated BATCH_MEAN_SUM / BATCH_VAR_SUM inputs (each of size C).
// Otherwise math::Moments reduces X to shape {1, C, 1} (NCHW, viewing X as
// {N, C, HxW}) or {1, C} (NHWC, viewing X as {N * HxW, C}).
// ComputeRunningMomentsAndFusedParam then updates the running statistics and
// fills alpha/beta, and AffineChannel writes Y.
138 if (num_batches_ > 1) {
139 const auto& batch_mean_sum =
Input(BATCH_MEAN_SUM);
140 const auto& batch_var_sum =
Input(BATCH_VAR_SUM);
141 CAFFE_ENFORCE_EQ(batch_mean_sum.numel(), C);
142 CAFFE_ENFORCE_EQ(batch_var_sum.numel(), C);
143 ComputeBatchMoments<T>(
147 batch_mean_sum.template data<T>(),
148 batch_var_sum.template data<T>(),
152 if (order_ == StorageOrder::NCHW) {
153 const std::array<int, 3> X_dims_arr = {N, C, HxW};
154 const std::array<int, 3> Y_dims_arr = {1, C, 1};
155 math::Moments<T, Context>(
164 const std::array<int, 2> X_dims_arr = {N * HxW, C};
165 const std::array<int, 2> Y_dims_arr = {1, C};
166 math::Moments<T, Context>(
176 ComputeRunningMomentsAndFusedParam<T>(
188 if (order_ == StorageOrder::NCHW) {
189 math::AffineChannel<T, Context, StorageOrder::NCHW>(
190 N, C, HxW, X_data, alpha_data, beta_data, Y_data, &context_);
192 math::AffineChannel<T, Context, StorageOrder::NHWC>(
193 N, C, HxW, X_data, alpha_data, beta_data, Y_data, &context_);
// Fuses per-channel scale, bias, mean and var (arrays of length C) into the
// affine pair used at inference time: beta = bias - alpha * mean.
// NOTE(review): as rendered here, alpha = scale * (var + epsilon_). Source
// line 213 is missing from this listing. For consistency with
// ComputeRunningMomentsAndFusedParam (alpha = scale * InvStd(var, epsilon_)),
// that missing line must apply an inverse square root (presumably `.rsqrt()`).
// Confirm against the full source.
200 template <
typename T>
201 void ComputeFusedParam(
209 EigenVectorArrayMap<T> alpha_arr(alpha, C);
210 EigenVectorArrayMap<T> beta_arr(beta, C);
211 alpha_arr = ConstEigenVectorArrayMap<T>(scale, C) *
212 (ConstEigenVectorArrayMap<T>(var, C) +
static_cast<T>(epsilon_))
214 beta_arr = ConstEigenVectorArrayMap<T>(bias, C) -
215 alpha_arr * ConstEigenVectorArrayMap<T>(mean, C);
// Derives per-channel moments (arrays of length C) from sums accumulated
// over num_batches_ batches:
//   mean = batch_mean_sum / (num_batches_ * N * HxW)
//   var  = batch_var_sum / (num_batches_ * N * HxW) - ...
// The rest of the variance expression (source line 232) is not in this
// listing; presumably it subtracts mean^2 (confirm).
// The element count is multiplied in 64-bit arithmetic. In plain int the
// product can overflow (undefined behaviour) for large inputs before it is
// converted to T. When it does not overflow, the result is unchanged.
218 template <
typename T>
219 void ComputeBatchMoments(
223 const T* batch_mean_sum,
224 const T* batch_var_sum,
227 const T scale =
T(1) /
static_cast<T>(static_cast<int64_t>(num_batches_) * N * HxW);
228 EigenVectorArrayMap<T> mean_arr(mean, C);
229 EigenVectorArrayMap<T> var_arr(var, C);
230 mean_arr = ConstEigenVectorArrayMap<T>(batch_mean_sum, C) * scale;
231 var_arr = ConstEigenVectorArrayMap<T>(batch_var_sum, C) * scale -
// Training-time helper (parameter list not in this listing). It:
//  1. blends the batch mean/var into running_mean/running_var via math::Axpby
//     with a = 1 - momentum_ and b = momentum_. Assuming Axpby computes
//     y = a * x + b * y, this is running = (1 - momentum) * batch
//     + momentum * running (confirm);
//  2. computes rstd from var and epsilon_ via math::InvStd;
//  3. sets alpha = scale * rstd and beta = bias - alpha * mean (length C).
235 template <
typename T>
236 void ComputeRunningMomentsAndFusedParam(
247 const T a =
T(1) -
static_cast<T>(momentum_);
248 const T b =
static_cast<T>(momentum_);
// Update the running statistics in place.
249 math::Axpby<T, T, Context>(C, a, mean, b, running_mean, &context_);
250 math::Axpby<T, T, Context>(C, a, var, b, running_var, &context_);
// Per-channel inverse standard deviation of the batch variance.
251 math::InvStd<T, Context>(C,
static_cast<T>(epsilon_), var, rstd, &context_);
// Fold scale/bias and the batch statistics into the affine coefficients.
252 EigenVectorArrayMap<T> alpha_arr(alpha, C);
253 EigenVectorArrayMap<T> beta_arr(beta, C);
254 alpha_arr = ConstEigenVectorArrayMap<T>(scale, C) *
255 ConstEigenVectorArrayMap<T>(rstd, C);
256 beta_arr = ConstEigenVectorArrayMap<T>(bias, C) -
257 alpha_arr * ConstEigenVectorArrayMap<T>(mean, C);
// Operator arguments parsed in the constructor (the is_test_ and epsilon_
// declarations are not in this listing).
262 const float momentum_;
263 const StorageOrder order_;
264 const int num_batches_;
// Output slots. The constructor requires 1 output when is_test_ is set and
// 5 otherwise.
277 OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_STD);
280 template <
class Context>
283 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor. Reads the operator arguments epsilon (default 1e-5),
// order (default "NCHW") and num_batches (default 1).
// NOTE(review): the class head and the constructor's name/initializer head
// are not present in this listing.
285 template <
class... Args>
288 OP_SINGLE_ARG(
double,
"epsilon", epsilon_, 1e-5),
289 order_(StringToStorageOrder(
290 this->
template GetSingleArgument<string>(
"order",
"NCHW"))),
291 OP_SINGLE_ARG(
int,
"num_batches", num_batches_, 1) {
// Presumably rejects an order that parsed to StorageOrder::UNKNOWN (the
// opening line of that enforce is not in this listing).
294 StorageOrder::UNKNOWN,
295 "order should be either \"NCHW\" or \"NHWC\".");
// 5 inputs, or 7 when the aggregated scale/bias gradients are also supplied
// (presumably the num_batches > 1 case). Always 3 outputs: dX, dscale, dbias.
296 CAFFE_ENFORCE(InputSize() == 5 || InputSize() == 7);
297 CAFFE_ENFORCE_EQ(OutputSize(), 3);
// Defaulted virtual destructor.
300 virtual ~SpatialBNGradientOp() =
default;
// Operator entry point. Its body is not present in this listing; presumably
// it dispatches to DoRunWithType<T> by the input's element type (confirm).
302 bool RunOnDevice()
override {
// Backward pass for element type T. Computes dX, dscale and dbias from X, dY,
// scale and the forward pass's saved mean / saved inverse std.
306 template <
typename T>
307 bool DoRunWithType() {
308 const auto& X =
Input(INPUT);
309 const auto& dY =
Input(OUTPUT_GRAD);
310 const auto& scale =
Input(SCALE);
311 const auto& mean =
Input(SAVED_MEAN);
312 const auto& rstd =
Input(SAVED_INV_STD);
// Same shape rules as the forward op: rank >= 3, N = dim 0, C = dim 1 (NCHW)
// or the last dim (NHWC). HxW comes from the non-batch dims. scale, the saved
// mean and the saved inverse std must each hold exactly C values.
313 const int ndim = X.dim();
314 CAFFE_ENFORCE_GE(ndim, 3);
315 const int N = X.dim32(0);
317 (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1));
318 const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
321 X_dims.cbegin() + 1, X_dims.cend(), 1, std::multiplies<int>()) /
323 CAFFE_ENFORCE_EQ(scale.numel(), C);
324 CAFFE_ENFORCE_EQ(mean.numel(), C);
325 CAFFE_ENFORCE_EQ(rstd.numel(), C);
// dX takes X's shape. dscale/dbias take scale's shape when num_batches_ == 1,
// otherwise the shapes of the AGGREGATE_SCALE_GRAD / AGGREGATE_BIAS_GRAD
// inputs.
327 auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
329 if (num_batches_ == 1) {
330 dscale_sizes = scale.sizes();
331 dbias_sizes = scale.sizes();
333 const auto& dscale_sum =
Input(AGGREGATE_SCALE_GRAD);
334 const auto& dbias_sum =
Input(AGGREGATE_BIAS_GRAD);
338 dscale_sizes = dscale_sum.sizes();
339 dbias_sizes = dbias_sum.sizes();
341 auto* dscale = Output(SCALE_GRAD, dscale_sizes, at::dtype<T>());
342 auto* dbias = Output(BIAS_GRAD, dbias_sizes, at::dtype<T>());
343 const T* X_data = X.template data<T>();
344 const T* dY_data = dY.template data<T>();
345 const T* scale_data = scale.template data<T>();
346 const T* mean_data = mean.template data<T>();
347 const T* rstd_data = rstd.template data<T>();
348 T* dX_data = dX->template mutable_data<T>();
349 T* dscale_data = dscale->template mutable_data<T>();
350 T* dbias_data = dbias->template mutable_data<T>();
// NOTE(review): the condition guarding these zero-fills of dscale/dbias is
// not in this listing (presumably an empty-batch check). After them,
// alpha_, beta_ and gamma_ are set up as per-channel ({C}) scratch tensors
// on the operator's device type.
353 math::Set<T, Context>(C,
T(0), dscale_data, &context_);
354 math::Set<T, Context>(C,
T(0), dbias_data, &context_);
358 &alpha_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
360 &beta_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
362 &gamma_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
363 T* alpha_data = alpha_.template mutable_data<T>();
364 T* beta_data = beta_.template mutable_data<T>();
365 T* gamma_data = gamma_.template mutable_data<T>();
// dscale/dbias and the fused coefficients alpha/beta/gamma come from the
// aggregated gradient sums when num_batches_ > 1. Otherwise they come from
// ComputeScaleBiasGradientsAndFusedParams (its arguments are not in this
// listing). Finally dX is computed from dY, X and alpha/beta/gamma. The name
// of that final call is not in this listing; presumably ComputeXGradient<T>.
366 if (num_batches_ > 1) {
367 const auto& dscale_sum =
Input(AGGREGATE_SCALE_GRAD);
368 const auto& dbias_sum =
Input(AGGREGATE_BIAS_GRAD);
369 ComputeMultiBatchScaleBiasGradientsAndFusedParams<T>(
376 dscale_sum.template data<T>(),
377 dbias_sum.template data<T>(),
384 ComputeScaleBiasGradientsAndFusedParams<T>(
401 N, C, HxW, dY_data, X_data, alpha_data, beta_data, gamma_data, dX_data);
// Trailing declarations of the gradient operator. The parameter lists and
// any bodies of these helpers are not in this listing:
//  - ComputeMultiBatchScaleBiasGradientsAndFusedParams: called when
//    num_batches_ > 1 with the aggregated dscale/dbias sums.
//  - ComputeScaleBiasGradientsAndFusedParams: the single-batch counterpart.
//  - ComputeXGradient: presumably produces dX from dY, X and the fused
//    coefficients (confirm).
// Then the order_/num_batches_ arguments (the epsilon_ declaration is not in
// this listing), the tail of the input tags, and the output tags follow.
// The input tags end with the optional aggregated scale/bias gradient inputs.
407 template <
typename T>
408 void ComputeMultiBatchScaleBiasGradientsAndFusedParams(
423 template <
typename T>
424 void ComputeScaleBiasGradientsAndFusedParams(
440 template <
typename T>
441 void ComputeXGradient(
453 const StorageOrder order_;
454 const int num_batches_;
467 AGGREGATE_SCALE_GRAD,
468 AGGREGATE_BIAS_GRAD);
// Output slots: dX, dscale, dbias.
469 OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
474 #endif // CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_ void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...