Caffe2 - C++ API
A deep learning, cross-platform ML framework
spatial_batch_norm_op.h
1 #ifndef CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
2 #define CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
3 
4 #include <algorithm>
5 #include <array>
6 #include <functional>
7 #include <string>
8 #include <vector>
9 
10 #include "caffe2/core/context.h"
11 #include "caffe2/core/operator.h"
12 #include "caffe2/utils/eigen_utils.h"
13 #include "caffe2/utils/math.h"
14 
15 namespace caffe2 {
16 
17 template <class Context>
18 class SpatialBNOp : public Operator<Context> {
19  public:
20  USE_OPERATOR_CONTEXT_FUNCTIONS;
21 
22  template <class... Args>
23  explicit SpatialBNOp(Args&&... args)
24  : Operator<Context>(std::forward<Args>(args)...),
25  OP_SINGLE_ARG(bool, OpSchema::Arg_IsTest, is_test_, false),
26  OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
27  OP_SINGLE_ARG(float, "momentum", momentum_, 0.9f),
28  order_(StringToStorageOrder(
29  this->template GetSingleArgument<std::string>("order", "NCHW"))),
30  OP_SINGLE_ARG(int, "num_batches", num_batches_, 1) {
31  CAFFE_ENFORCE_NE(
32  order_,
33  StorageOrder::UNKNOWN,
34  "order should be either \"NCHW\" or \"NHWC\".");
35  CAFFE_ENFORCE(
36  (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 5));
37  CAFFE_ENFORCE_GT(epsilon_, 0);
38  CAFFE_ENFORCE_GE(momentum_, 0);
39  CAFFE_ENFORCE_LE(momentum_, 1);
40  }
41 
42  virtual ~SpatialBNOp() = default;
43 
44  bool RunOnDevice() override {
45  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
46  }
47 
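 // Forward pass. In test mode this computes
 //   Y = scale * (X - est_mean) / sqrt(est_var + epsilon) + bias
 // from the estimated moments in EST_MEAN / EST_VAR. In training mode the
 // batch moments are computed here (or combined across num_batches_
 // micro-batches), the running moments are updated with momentum, and the
 // saved mean / inverse std are written out for the backward pass.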
48  template <typename T>
49  bool DoRunWithType() {
50  const auto& X = Input(INPUT);
51  const auto& scale = Input(SCALE);
52  const auto& bias = Input(BIAS);
53 
54  const int ndim = X.dim();
55  CAFFE_ENFORCE_GE(ndim, 3);
56  const int N = X.dim32(0);
57  const int C =
58  (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1));
59  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
60  const int HxW =
61  std::accumulate(
62  X_dims.cbegin() + 1, X_dims.cend(), 1, std::multiplies<int>()) /
63  C;
64  CAFFE_ENFORCE_EQ(scale.numel(), C);
65  CAFFE_ENFORCE_EQ(bias.numel(), C);
66 
67  auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
68  const T* X_data = X.template data<T>();
69  const T* scale_data = scale.template data<T>();
70  const T* bias_data = bias.template data<T>();
71  T* Y_data = Y->template mutable_data<T>();
 72  ReinitializeTensor(
 73  &alpha_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
 74  ReinitializeTensor(
 75  &beta_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
76  T* alpha_data = alpha_.template mutable_data<T>();
77  T* beta_data = beta_.template mutable_data<T>();
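 // Both branches below only fill the per-channel coefficients alpha_ and
 // beta_; the normalization itself is then a single affine pass,
 // Y = alpha * X + beta, via math::AffineChannel at the end.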
78  if (is_test_) {
79  if (N == 0) {
80  return true;
81  }
82  const auto& mean = Input(EST_MEAN);
83  const auto& var = Input(EST_VAR);
84  CAFFE_ENFORCE_EQ(mean.numel(), C);
85  CAFFE_ENFORCE_EQ(var.numel(), C);
86  ComputeFusedParam<T>(
87  C,
88  scale_data,
89  bias_data,
90  mean.template data<T>(),
91  var.template data<T>(),
92  alpha_data,
93  beta_data);
94  } else {
95  auto* saved_mean = Output(SAVED_MEAN, {C}, at::dtype<T>());
96  auto* saved_rstd = Output(SAVED_INV_STD, {C}, at::dtype<T>());
97  T* saved_mean_data = saved_mean->template mutable_data<T>();
98  T* saved_rstd_data = saved_rstd->template mutable_data<T>();
99 
 100  // Enforce that the running stats inputs alias the corresponding outputs.
 101  CAFFE_ENFORCE(
 102  IsInputOutputAlias(3, 1), "Input 3 and Output 1 should be aliases.");
 103  CAFFE_ENFORCE(
 104  IsInputOutputAlias(4, 2), "Input 4 and Output 2 should be aliases.");
105 
106  Tensor* running_mean = nullptr;
107  Tensor* running_var = nullptr;
108  const auto& mean = Input(EST_MEAN);
109  const auto& var = Input(EST_VAR);
110  if (mean.numel() != C) {
111  running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
112  C10_LOG_EVERY_MS(WARNING, 1000)
 113  << "[Deprecated] Running mean is not initialized in 
114  "SpatialBatchNorm Op";
115  math::Set<T, Context>(
116  C, T(0), running_mean->template mutable_data<T>(), &context_);
117  } else {
118  running_mean = Output(RUNNING_MEAN, {C}, at::dtype<T>());
119  }
120  if (var.numel() != C) {
121  running_var = Output(RUNNING_VAR, {C}, at::dtype<T>());
122  math::Set<T, Context>(
123  C, T(0), running_var->template mutable_data<T>(), &context_);
124  C10_LOG_EVERY_MS(WARNING, 1000)
125  << "[Deprecated] Running variance is not initialized in "
126  "SpatialBatchNorm Op";
127  } else {
128  running_var = Output(RUNNING_VAR, {C}, at::dtype<T>());
129  }
130 
131  T* running_mean_data = running_mean->template mutable_data<T>();
132  T* running_var_data = running_var->template mutable_data<T>();
133  if (N == 0) {
134  math::Set<T, Context>(C, T(0), saved_mean_data, &context_);
135  math::Set<T, Context>(C, T(0), saved_rstd_data, &context_);
136  return true;
137  }
138  if (num_batches_ > 1) {
139  const auto& batch_mean_sum = Input(BATCH_MEAN_SUM);
140  const auto& batch_var_sum = Input(BATCH_VAR_SUM);
141  CAFFE_ENFORCE_EQ(batch_mean_sum.numel(), C);
142  CAFFE_ENFORCE_EQ(batch_var_sum.numel(), C);
143  ComputeBatchMoments<T>(
144  N,
145  C,
146  HxW,
147  batch_mean_sum.template data<T>(),
148  batch_var_sum.template data<T>(),
149  saved_mean_data,
150  saved_rstd_data);
151  } else {
152  if (order_ == StorageOrder::NCHW) {
153  const std::array<int, 3> X_dims_arr = {N, C, HxW};
154  const std::array<int, 3> Y_dims_arr = {1, C, 1};
155  math::Moments<T, Context>(
156  3,
157  X_dims_arr.data(),
158  Y_dims_arr.data(),
159  X_data,
160  saved_mean_data,
161  saved_rstd_data,
162  &context_);
163  } else {
164  const std::array<int, 2> X_dims_arr = {N * HxW, C};
165  const std::array<int, 2> Y_dims_arr = {1, C};
166  math::Moments<T, Context>(
167  2,
168  X_dims_arr.data(),
169  Y_dims_arr.data(),
170  X_data,
171  saved_mean_data,
172  saved_rstd_data,
173  &context_);
174  }
175  }
176  ComputeRunningMomentsAndFusedParam<T>(
177  C,
178  scale_data,
179  bias_data,
180  saved_mean_data,
181  saved_rstd_data,
182  running_mean_data,
183  running_var_data,
184  saved_rstd_data,
185  alpha_data,
186  beta_data);
187  }
188  if (order_ == StorageOrder::NCHW) {
189  math::AffineChannel<T, Context, StorageOrder::NCHW>(
190  N, C, HxW, X_data, alpha_data, beta_data, Y_data, &context_);
191  } else {
192  math::AffineChannel<T, Context, StorageOrder::NHWC>(
193  N, C, HxW, X_data, alpha_data, beta_data, Y_data, &context_);
194  }
195 
196  return true;
197  }
198 
199  protected:
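 // Folds normalization and the learned affine transform into per-channel
 // coefficients: alpha = scale / sqrt(var + epsilon) and
 // beta = bias - alpha * mean, so that Y = alpha * X + beta.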
200  template <typename T>
201  void ComputeFusedParam(
202  const int C,
203  const T* scale,
204  const T* bias,
205  const T* mean,
206  const T* var,
207  T* alpha,
208  T* beta) {
209  EigenVectorArrayMap<T> alpha_arr(alpha, C);
210  EigenVectorArrayMap<T> beta_arr(beta, C);
211  alpha_arr = ConstEigenVectorArrayMap<T>(scale, C) *
212  (ConstEigenVectorArrayMap<T>(var, C) + static_cast<T>(epsilon_))
213  .rsqrt();
214  beta_arr = ConstEigenVectorArrayMap<T>(bias, C) -
215  alpha_arr * ConstEigenVectorArrayMap<T>(mean, C);
216  }
217 
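 // Converts sums aggregated over num_batches_ micro-batches into moments:
 // with M = num_batches_ * N * HxW, mean = batch_mean_sum / M and
 // var = batch_var_sum / M - mean^2 (the inputs hold sum(X) and sum(X^2)).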
218  template <typename T>
219  void ComputeBatchMoments(
220  const int N,
221  const int C,
222  const int HxW,
223  const T* batch_mean_sum,
224  const T* batch_var_sum,
225  T* mean,
226  T* var) {
227  const T scale = T(1) / static_cast<T>(num_batches_ * N * HxW);
228  EigenVectorArrayMap<T> mean_arr(mean, C);
229  EigenVectorArrayMap<T> var_arr(var, C);
230  mean_arr = ConstEigenVectorArrayMap<T>(batch_mean_sum, C) * scale;
231  var_arr = ConstEigenVectorArrayMap<T>(batch_var_sum, C) * scale -
232  mean_arr.square();
233  }
234 
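 // Updates running_x = momentum * running_x + (1 - momentum) * x via Axpby,
 // converts the saved variance to an inverse standard deviation (the caller
 // passes the same buffer as `var` and `rstd`, so the conversion happens in
 // place), and then builds the same fused alpha/beta as ComputeFusedParam.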
235  template <typename T>
236  void ComputeRunningMomentsAndFusedParam(
237  const int C,
238  const T* scale,
239  const T* bias,
240  const T* mean,
241  const T* var,
242  T* running_mean,
243  T* running_var,
244  T* rstd,
245  T* alpha,
246  T* beta) {
247  const T a = T(1) - static_cast<T>(momentum_);
248  const T b = static_cast<T>(momentum_);
249  math::Axpby<T, T, Context>(C, a, mean, b, running_mean, &context_);
250  math::Axpby<T, T, Context>(C, a, var, b, running_var, &context_);
251  math::InvStd<T, Context>(C, static_cast<T>(epsilon_), var, rstd, &context_);
252  EigenVectorArrayMap<T> alpha_arr(alpha, C);
253  EigenVectorArrayMap<T> beta_arr(beta, C);
254  alpha_arr = ConstEigenVectorArrayMap<T>(scale, C) *
255  ConstEigenVectorArrayMap<T>(rstd, C);
256  beta_arr = ConstEigenVectorArrayMap<T>(bias, C) -
257  alpha_arr * ConstEigenVectorArrayMap<T>(mean, C);
258  }
259 
260  const bool is_test_;
261  double epsilon_;
262  const float momentum_;
263  const StorageOrder order_;
264  const int num_batches_;
265 
266  Tensor alpha_;
267  Tensor beta_;
268 
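 // BATCH_MEAN_SUM and BATCH_VAR_SUM are only consumed when num_batches_ > 1.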
269  INPUT_TAGS(
270  INPUT,
271  SCALE,
272  BIAS,
273  EST_MEAN,
274  EST_VAR,
275  BATCH_MEAN_SUM,
276  BATCH_VAR_SUM);
277  OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_STD);
278 };
279 
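// A minimal usage sketch (not part of this header): running SpatialBN in
// inference mode through the generic operator interface. Blob setup for the
// five inputs is assumed to happen elsewhere; the blob names are
// illustrative only.
//
//   OperatorDef def;
//   def.set_type("SpatialBN");
//   for (const char* name : {"X", "scale", "bias", "mean", "var"}) {
//     def.add_input(name);
//   }
//   def.add_output("Y");
//   auto* arg = def.add_arg();
//   arg->set_name("is_test");
//   arg->set_i(1);
//   Workspace ws;
//   // ... feed X, scale, bias, mean and var into ws ...
//   auto op = CreateOperator(def, &ws);
//   op->Run();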
280 template <class Context>
281 class SpatialBNGradientOp : public Operator<Context> {
282  public:
283  USE_OPERATOR_CONTEXT_FUNCTIONS;
284 
285  template <class... Args>
286  explicit SpatialBNGradientOp(Args&&... args)
287  : Operator<Context>(std::forward<Args>(args)...),
288  OP_SINGLE_ARG(double, "epsilon", epsilon_, 1e-5),
289  order_(StringToStorageOrder(
 290  this->template GetSingleArgument<std::string>("order", "NCHW"))),
291  OP_SINGLE_ARG(int, "num_batches", num_batches_, 1) {
292  CAFFE_ENFORCE_NE(
293  order_,
294  StorageOrder::UNKNOWN,
295  "order should be either \"NCHW\" or \"NHWC\".");
296  CAFFE_ENFORCE(InputSize() == 5 || InputSize() == 7);
297  CAFFE_ENFORCE_EQ(OutputSize(), 3);
298  }
299 
300  virtual ~SpatialBNGradientOp() = default;
301 
302  bool RunOnDevice() override {
303  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
304  }
305 
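 // Backward pass. Given dY and the mean / inverse std saved by the forward
 // pass, this produces dX, dscale and dbias. When num_batches_ > 1 the
 // dscale/dbias sums come pre-aggregated in AGGREGATE_SCALE_GRAD /
 // AGGREGATE_BIAS_GRAD instead of being reduced here.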
306  template <typename T>
307  bool DoRunWithType() {
308  const auto& X = Input(INPUT);
309  const auto& dY = Input(OUTPUT_GRAD);
310  const auto& scale = Input(SCALE);
311  const auto& mean = Input(SAVED_MEAN);
312  const auto& rstd = Input(SAVED_INV_STD);
313  const int ndim = X.dim();
314  CAFFE_ENFORCE_GE(ndim, 3);
315  const int N = X.dim32(0);
316  const int C =
317  (order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1));
318  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
319  const int HxW =
320  std::accumulate(
321  X_dims.cbegin() + 1, X_dims.cend(), 1, std::multiplies<int>()) /
322  C;
323  CAFFE_ENFORCE_EQ(scale.numel(), C);
324  CAFFE_ENFORCE_EQ(mean.numel(), C);
325  CAFFE_ENFORCE_EQ(rstd.numel(), C);
326 
327  auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
328  at::IntArrayRef dscale_sizes, dbias_sizes;
329  if (num_batches_ == 1) {
330  dscale_sizes = scale.sizes();
331  dbias_sizes = scale.sizes();
332  } else {
333  const auto& dscale_sum = Input(AGGREGATE_SCALE_GRAD);
334  const auto& dbias_sum = Input(AGGREGATE_BIAS_GRAD);
 335  // Note: there was previously an alias check here to decide whether to
 336  // call ResizeLike. Since Resize is only invoked when the size does not
 337  // match the size of the cached tensor, that check is unnecessary.
338  dscale_sizes = dscale_sum.sizes();
339  dbias_sizes = dbias_sum.sizes();
340  }
341  auto* dscale = Output(SCALE_GRAD, dscale_sizes, at::dtype<T>());
342  auto* dbias = Output(BIAS_GRAD, dbias_sizes, at::dtype<T>());
343  const T* X_data = X.template data<T>();
344  const T* dY_data = dY.template data<T>();
345  const T* scale_data = scale.template data<T>();
346  const T* mean_data = mean.template data<T>();
347  const T* rstd_data = rstd.template data<T>();
348  T* dX_data = dX->template mutable_data<T>();
349  T* dscale_data = dscale->template mutable_data<T>();
350  T* dbias_data = dbias->template mutable_data<T>();
351 
352  if (N == 0) {
353  math::Set<T, Context>(C, T(0), dscale_data, &context_);
354  math::Set<T, Context>(C, T(0), dbias_data, &context_);
355  return true;
356  }
 357  ReinitializeTensor(
 358  &alpha_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
 359  ReinitializeTensor(
 360  &beta_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
 361  ReinitializeTensor(
 362  &gamma_, {C}, at::dtype<T>().device(Context::GetDeviceType()));
363  T* alpha_data = alpha_.template mutable_data<T>();
364  T* beta_data = beta_.template mutable_data<T>();
365  T* gamma_data = gamma_.template mutable_data<T>();
366  if (num_batches_ > 1) {
367  const auto& dscale_sum = Input(AGGREGATE_SCALE_GRAD);
368  const auto& dbias_sum = Input(AGGREGATE_BIAS_GRAD);
369  ComputeMultiBatchScaleBiasGradientsAndFusedParams<T>(
370  N,
371  C,
372  HxW,
373  scale_data,
374  mean_data,
375  rstd_data,
376  dscale_sum.template data<T>(),
377  dbias_sum.template data<T>(),
378  dscale_data,
379  dbias_data,
380  alpha_data,
381  beta_data,
382  gamma_data);
383  } else {
384  ComputeScaleBiasGradientsAndFusedParams<T>(
385  N,
386  C,
387  HxW,
388  dY_data,
389  X_data,
390  scale_data,
391  mean_data,
392  rstd_data,
393  dscale_data,
394  dbias_data,
395  alpha_data,
396  beta_data,
397  gamma_data,
398  dX_data);
399  }
400  ComputeXGradient<T>(
401  N, C, HxW, dY_data, X_data, alpha_data, beta_data, gamma_data, dX_data);
402 
403  return true;
404  }
405 
406  protected:
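 // Finalizes dscale/dbias from sums pre-aggregated across num_batches_
 // micro-batches and derives the per-channel fused coefficients consumed by
 // ComputeXGradient. (Definitions live in the per-device .cc/.cu files.)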
407  template <typename T>
408  void ComputeMultiBatchScaleBiasGradientsAndFusedParams(
409  const int N,
410  const int C,
411  const int HxW,
412  const T* scale,
413  const T* mean,
414  const T* rstd,
415  const T* dscale_sum,
416  const T* dbias_sum,
417  T* dscale,
418  T* dbias,
419  T* alpha,
420  T* beta,
421  T* gamma);
422 
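 // Single-batch path: reduces the standard batch norm gradients,
 // dbias = sum(dY) and dscale = sum(dY * (X - mean) * rstd), and derives
 // the fused coefficients; `scratch` is backed by dX's storage here.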
423  template <typename T>
424  void ComputeScaleBiasGradientsAndFusedParams(
425  const int N,
426  const int C,
427  const int HxW,
428  const T* dY,
429  const T* X,
430  const T* scale,
431  const T* mean,
432  const T* rstd,
433  T* dscale,
434  T* dbias,
435  T* alpha,
436  T* beta,
437  T* gamma,
438  T* scratch);
439 
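 // Applies the fused per-channel coefficients:
 // dX = alpha * dY + beta * X + gamma.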
440  template <typename T>
441  void ComputeXGradient(
442  const int N,
443  const int C,
444  const int HxW,
445  const T* dY,
446  const T* X,
447  const T* alpha,
448  const T* beta,
449  const T* gamma,
450  T* dX);
451 
452  double epsilon_;
453  const StorageOrder order_;
454  const int num_batches_;
455 
456  Tensor alpha_;
457  Tensor beta_;
458  Tensor gamma_;
459  Tensor ones_;
460 
461  INPUT_TAGS(
462  INPUT,
463  SCALE,
464  OUTPUT_GRAD,
465  SAVED_MEAN,
466  SAVED_INV_STD,
467  AGGREGATE_SCALE_GRAD,
468  AGGREGATE_BIAS_GRAD);
469  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
470 };
471 
472 } // namespace caffe2
473 
474 #endif // CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_