Caffe2 - C++ API
A deep learning, cross-platform ML framework
instance_norm_op.h
1 
17 #ifndef CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
18 #define CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
19 
20 #include "caffe2/core/context.h"
21 #include "caffe2/core/operator.h"
22 #include "caffe2/utils/math.h"
23 
24 namespace caffe2 {
25 
26 template <typename T, class Context>
27 class InstanceNormOp : public Operator<Context> {
28  public:
29  USE_OPERATOR_CONTEXT_FUNCTIONS;
30  InstanceNormOp(const OperatorDef& operator_def, Workspace* ws)
31  : Operator<Context>(operator_def, ws),
32  epsilon_(OperatorBase::GetSingleArgument<T>("epsilon", 1e-5f)),
33  order_(StringToStorageOrder(
34  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
35  CAFFE_ENFORCE(epsilon_ >= 0, "Must pass a nonnegative epsilon.");
36  }
37  ~InstanceNormOp() {}
38 
39  bool RunOnDevice() {
40  switch (order_) {
41  case StorageOrder::NHWC:
42  return RunOnDeviceWithOrderNHWC();
43  case StorageOrder::NCHW:
44  return RunOnDeviceWithOrderNCHW();
45  default:
46  CAFFE_THROW("Unknown storage order: ", order_);
47  }
48  }
49 
50  bool RunOnDeviceWithOrderNHWC();
51  bool RunOnDeviceWithOrderNCHW();
52 
53  protected:
54  // parameters
55  T epsilon_;
56  StorageOrder order_;
57 
58  // temp results that get passed to the gradient, but are otherwise stored here
59  Tensor<Context> mean_;
60  Tensor<Context> inv_stdev_;
61 
62  INPUT_TAGS(INPUT, SCALE, BIAS);
63  OUTPUT_TAGS(OUTPUT, MEAN, INV_STDEV);
64 };
65 
66 template <typename T, class Context>
67 class InstanceNormGradientOp : public Operator<Context> {
68  public:
69  USE_OPERATOR_CONTEXT_FUNCTIONS;
70  InstanceNormGradientOp(const OperatorDef& operator_def, Workspace* ws)
71  : Operator<Context>(operator_def, ws),
72  epsilon_(OperatorBase::GetSingleArgument<T>("epsilon", 1e-5f)),
73  order_(StringToStorageOrder(
74  OperatorBase::GetSingleArgument<string>("order", "NCHW"))) {
75  CAFFE_ENFORCE(epsilon_ >= 0, "Must pass a nonnegative epsilon.");
76  }
78 
79  bool RunOnDevice() {
80  switch (order_) {
81  case StorageOrder::NHWC:
82  return RunOnDeviceWithOrderNHWC();
83  case StorageOrder::NCHW:
84  return RunOnDeviceWithOrderNCHW();
85  default:
86  CAFFE_THROW("Unknown storage order: ", order_);
87  }
88  }
89 
90  bool RunOnDeviceWithOrderNHWC();
91  bool RunOnDeviceWithOrderNCHW();
92 
93  protected:
94  // parameters
95  T epsilon_;
96  StorageOrder order_;
97 
98  // temp results that could get passed through to this gradient, but if not,
99  // are stored here
100  Tensor<Context> mean_;
101  Tensor<Context> inv_stdev_;
102 
103  INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, INV_STDEV);
104  OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
105 };
106 
107 } // namespace caffe2
108 
109 #endif // CAFFE2_OPERATORS_INSTANCE_NORM_OP_H_
Tensor is the basic class in Caffe2 that stores a contiguous block of memory together with its shape information...
Definition: tensor.h:109
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:63
Copyright (c) 2016-present, Facebook, Inc.