Caffe2 - C++ API
A deep-learning, cross-platform ML framework
spatial_batch_norm_op.h
1 
17 #ifndef CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
18 #define CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <class Context>
27 class SpatialBNOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  SpatialBNOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
33  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)),
34  momentum_(OperatorBase::GetSingleArgument<float>("momentum", 0.9f)),
35  order_(StringToStorageOrder(
36  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
37  num_batches_(OperatorBase::GetSingleArgument<int>("num_batches", 1)) {
38  // TODO(jiayq): update the input and output size checks.
39  CAFFE_ENFORCE(
40  (is_test_ && OutputSize() == 1) || (!is_test_ && OutputSize() == 5));
41  CAFFE_ENFORCE_GT(epsilon_, 0);
42  CAFFE_ENFORCE_GE(momentum_, 0);
43  CAFFE_ENFORCE_LE(momentum_, 1);
44  }
45  ~SpatialBNOp() {}
46 
47  bool RunOnDevice() override {
48  return true;
49  }
50 
51  protected:
52  bool is_test_;
53  double epsilon_;
54  double momentum_;
55  StorageOrder order_;
56  int num_batches_;
57  INPUT_TAGS(INPUT, SCALE, BIAS, EST_MEAN, EST_VAR, SUMS, SUMSQ);
58  OUTPUT_TAGS(OUTPUT, RUNNING_MEAN, RUNNING_VAR, SAVED_MEAN, SAVED_INV_VAR);
59 };
60 
61 template <class Context>
62 class SpatialBNGradientOp : public Operator<Context> {
63  public:
64  USE_OPERATOR_CONTEXT_FUNCTIONS;
65  SpatialBNGradientOp(const OperatorDef& operator_def, Workspace* ws)
66  : Operator<Context>(operator_def, ws),
67  is_test_(OperatorBase::GetSingleArgument<int>(OpSchema::Arg_IsTest, 0)),
68  epsilon_(OperatorBase::GetSingleArgument<float>("epsilon", 1e-5f)),
69  order_(StringToStorageOrder(
70  OperatorBase::GetSingleArgument<string>("order", "NCHW"))),
71  num_batches_(OperatorBase::GetSingleArgument<int>("num_batches", 1)) {
72  CAFFE_ENFORCE(InputSize() == 5 || InputSize() == 7);
73  CAFFE_ENFORCE(OutputSize() == 3);
74  }
76 
77  bool RunOnDevice() override {
78  return true;
79  }
80 
81  protected:
82  bool is_test_;
83  double epsilon_;
84  StorageOrder order_;
85  int num_batches_;
86 
87  INPUT_TAGS(
88  INPUT,
89  SCALE,
90  OUTPUT_GRAD,
91  SAVED_MEAN,
92  SAVED_INV_VAR,
93  AGGREGATE_SCALE_GRAD,
94  AGGREGATE_BIAS_GRAD);
95  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
96 };
97 
98 } // namespace caffe2
99 
100 #endif // CAFFE2_OPERATORS_SPATIAL_BATCH_NORM_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.