Caffe2 - C++ API
A deep learning, cross platform ML framework
spatial_batch_norm_op.cc
1 #include <caffe2/ideep/ideep_utils.h>
2 
3 namespace caffe2 {
4 
5 class IDEEPSpatialBNOp final : public IDEEPOperator {
6  public:
7  USE_IDEEP_DEF_ALIASES();
8  USE_IDEEP_OPERATOR_FUNCTIONS();
9 
10  IDEEPSpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
11  : IDEEPOperator(operator_def, ws),
12  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
13  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5)),
14  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.9)) {
15  CAFFE_ENFORCE(
16  (is_test_ && OutputSize() > OUTPUT)
17  || (!is_test_ && OutputSize() > SAVED_VAR));
18  CAFFE_ENFORCE_GT(epsilon_, 0);
19  CAFFE_ENFORCE_GE(momentum_, 0);
20  CAFFE_ENFORCE_LE(momentum_, 1);
21  }
22  ~IDEEPSpatialBNOp() override {}
23 
24  bool RunOnDevice() override {
25  const auto& X = Input(INPUT);
26  const auto& scale = Input(SCALE);
27  const auto& bias = Input(BIAS);
28  auto* Y = Output(OUTPUT);
29 
30  DCHECK_EQ(scale.ndims(), 1);
31  DCHECK_EQ(bias.ndims(), 1);
32  DCHECK_EQ(scale.get_dim(0), X.get_dim(1));
33  DCHECK_EQ(bias.get_dim(0), X.get_dim(1));
34 
35  if (is_test_) {
36  const auto& est_mean = Input(EST_MEAN);
37  const auto& est_var = Input(EST_VAR);
38  ideep::batch_normalization_forward_inference::compute(
39  X, est_mean, est_var, scale, bias, *Y, epsilon_);
40  } else {
41  auto* saved_mean = Output(SAVED_MEAN);
42  auto* saved_var = Output(SAVED_VAR);
43  auto* running_mean = Output(RUNNING_MEAN);
44  auto* running_var = Output(RUNNING_VAR);
45  ideep::batch_normalization_forward_training::compute(
46  X, scale, bias, *Y, *saved_mean, *saved_var,
47  *running_mean, *running_var, momentum_, epsilon_);
48  }
49 
50  return true;
51  }
52 
53  private:
54  bool is_test_;
55  double epsilon_;
56  double momentum_;
57 
58  INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR);
59  OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_VAR);
60 };
61 
63  public:
64  USE_IDEEP_DEF_ALIASES();
65  USE_IDEEP_OPERATOR_FUNCTIONS();
66 
67  IDEEPSpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
68  : IDEEPOperator(operator_def, ws),
69  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5)) {
70  CAFFE_ENFORCE(InputSize() > SAVED_VAR);
71  CAFFE_ENFORCE(OutputSize() > BIAS_GRAD);
72  }
73  ~IDEEPSpatialBNGradientOp() override {}
74 
75  bool RunOnDevice() override {
76  const auto& X = Input(INPUT);
77  const auto& scale = Input(SCALE);
78  const auto& dY = Input(OUTPUT_GRAD);
79  const auto& saved_mean = Input(SAVED_MEAN);
80  const auto& saved_var = Input(SAVED_VAR);
81  auto* dX = Output(INPUT_GRAD);
82  auto* dscale = Output(SCALE_GRAD);
83  auto* dbias = Output(BIAS_GRAD);
84 
85  ideep::batch_normalization_backward::compute(
86  X, saved_mean, saved_var, dY, scale,
87  *dX, *dscale, *dbias, epsilon_);
88 
89  return true;
90  }
91 
92  private:
93  double epsilon_;
94 
95  INPUT_TAGS(INPUT, SCALE, OUTPUT_GRAD, SAVED_MEAN, SAVED_VAR);
96  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
97 };
98 
99 REGISTER_IDEEP_OPERATOR(SpatialBN, IDEEPSpatialBNOp);
100 REGISTER_IDEEP_OPERATOR(SpatialBNGradient, IDEEPSpatialBNGradientOp)
101 
102 } // namespace caffe2
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime, and also utility functions to load modules.
Definition: blob.h:13