Caffe2 - C++ API
A deep-learning, cross-platform ML framework
instance_norm_op.h
1 #ifndef CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
2 #define CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/utils/math.h"
7 
8 namespace caffe2 {
9 
10 template <typename T, class Context>
11 class InstanceNormOp : public Operator<Context> {
12  public:
13  USE_OPERATOR_CONTEXT_FUNCTIONS;
14  template <class... Args>
15  explicit InstanceNormOp(Args&&... args)
16  : Operator<Context>(std::forward<Args>(args)...),
17  epsilon_(this->template GetSingleArgument<T>("epsilon", 1e-5f)),
18  order_(StringToStorageOrder(
19  this->template GetSingleArgument<string>("order", "NCHW"))) {
20  CAFFE_ENFORCE(epsilon_ >= 0, "Must pass a nonnegative epsilon.");
21  }
22  ~InstanceNormOp() {}
23 
24  bool RunOnDevice() {
25  switch (order_) {
26  case StorageOrder::NHWC:
27  return RunOnDeviceWithOrderNHWC();
28  case StorageOrder::NCHW:
29  return RunOnDeviceWithOrderNCHW();
30  default:
31  CAFFE_THROW("Unknown storage order: ", order_);
32  }
33  }
34 
35  bool RunOnDeviceWithOrderNHWC();
36  bool RunOnDeviceWithOrderNCHW();
37 
38  protected:
39  // parameters
40  T epsilon_;
41  StorageOrder order_;
42 
43  // temp results that get passed to the gradient, but are otherwise stored here
44  Tensor mean_{Context::GetDeviceType()};
45  Tensor inv_stdev_{Context::GetDeviceType()};
46 
47  INPUT_TAGS(INPUT, SCALE, BIAS);
48  OUTPUT_TAGS(OUTPUT, MEAN, INV_STDEV);
49 };
50 
51 template <typename T, class Context>
52 class InstanceNormGradientOp : public Operator<Context> {
53  public:
54  USE_OPERATOR_CONTEXT_FUNCTIONS;
55  template <class... Args>
56  explicit InstanceNormGradientOp(Args&&... args)
57  : Operator<Context>(std::forward<Args>(args)...),
58  epsilon_(this->template GetSingleArgument<T>("epsilon", 1e-5f)),
59  order_(StringToStorageOrder(
60  this->template GetSingleArgument<string>("order", "NCHW"))) {
61  CAFFE_ENFORCE(epsilon_ >= 0, "Must pass a nonnegative epsilon.");
62  }
63  ~InstanceNormGradientOp() {}
64 
65  bool RunOnDevice() {
66  switch (order_) {
67  case StorageOrder::NHWC:
68  return RunOnDeviceWithOrderNHWC();
69  case StorageOrder::NCHW:
70  return RunOnDeviceWithOrderNCHW();
71  default:
72  CAFFE_THROW("Unknown storage order: ", order_);
73  }
74  }
75 
76  bool RunOnDeviceWithOrderNHWC();
77  bool RunOnDeviceWithOrderNCHW();
78 
79  protected:
80  // parameters
81  T epsilon_;
82  StorageOrder order_;
83 
84  // temp results that could get passed through to this gradient, but if not,
85  // are stored here
86  Tensor mean_;
87  Tensor inv_stdev_;
88 
89  INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, INV_STDEV);
90  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
91 };
92 
93 } // namespace caffe2
94 
95 #endif // CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13