Caffe2 - C++ API
A deep learning, cross platform ML framework
group_norm_op.h
1 #ifndef CAFFE2_OPERATORS_GROUP_NORM_OP_H_
2 #define CAFFE2_OPERATORS_GROUP_NORM_OP_H_
3 
4 #include <array>
5 #include <string>
6 #include <vector>
7 
8 #include "caffe2/core/common.h"
9 #include "caffe2/core/context.h"
10 #include "caffe2/core/operator.h"
11 #include "caffe2/utils/eigen_utils.h"
12 #include "caffe2/utils/math.h"
13 
14 namespace caffe2 {
15 
16 template <typename T, class Context>
17 class GroupNormOp final : public Operator<Context> {
18  public:
19  USE_OPERATOR_CONTEXT_FUNCTIONS;
20 
21  template <class... Args>
22  explicit GroupNormOp(Args&&... args)
23  : Operator<Context>(std::forward<Args>(args)...),
24  OP_SINGLE_ARG(int, "group", group_, 32),
25  OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5),
26  order_(StringToStorageOrder(
27  this->template GetSingleArgument<std::string>("order", "NCHW"))),
28  OP_SINGLE_ARG(bool, OpSchema::Arg_IsTest, is_test_, true) {
29  CAFFE_ENFORCE_NE(
30  order_,
31  StorageOrder::UNKNOWN,
32  "order should be either \"NCHW\" or \"NHWC\".");
33  if (!is_test_) {
34  CAFFE_ENFORCE_EQ(OutputSize(), 3);
35  }
36  }
37 
38  bool RunOnDevice() override {
39  const auto& X = Input(INPUT);
40  const auto& gamma = Input(GAMMA);
41  const auto& beta = Input(BETA);
42  const int ndim = X.dim();
43  const int N = X.dim32(0);
44  const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
45  const int HxW = X.numel() / (N * C);
46  CAFFE_ENFORCE_EQ(C % group_, 0);
47  CAFFE_ENFORCE_EQ(gamma.numel(), C);
48  CAFFE_ENFORCE_EQ(beta.numel(), C);
49  const int G = group_;
50  const int D = C / G;
51 
52  auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
53  T* mu_data = nullptr;
54  T* rsig_data = nullptr;
55  if (OutputSize() == 3) {
56  auto* mu = Output(MU, {N, G}, at::dtype<T>());
57  auto* rsig = Output(INV_SIGMA, {N, G}, at::dtype<T>());
58  mu_data = mu->template mutable_data<T>();
59  rsig_data = rsig->template mutable_data<T>();
60  } else {
62  &mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
64  &rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
65  mu_data = mu_.template mutable_data<T>();
66  rsig_data = rsig_.template mutable_data<T>();
67  }
68  return RunOnDeviceImpl(
69  N,
70  G,
71  D,
72  HxW,
73  X.template data<T>(),
74  gamma.template data<T>(),
75  beta.template data<T>(),
76  Y->template mutable_data<T>(),
77  mu_data,
78  rsig_data);
79  }
80 
81  protected:
82  bool RunOnDeviceImpl(
83  const int N,
84  const int G,
85  const int D,
86  const int HxW,
87  const T* X,
88  const T* gamma,
89  const T* beta,
90  T* Y,
91  T* mu,
92  T* rsig) {
93  const int C = G * D;
95  &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
97  &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
98  T* scale_data = scale_.template mutable_data<T>();
99  T* bias_data = bias_.template mutable_data<T>();
100  if (order_ == StorageOrder::NCHW) {
101  const std::array<int, 2> X_dims = {N * G, D * HxW};
102  const std::array<int, 2> Y_dims = {N * G, 1};
103  math::Moments<T, Context>(
104  2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
105  math::InvStd<T, Context>(
106  N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
107  ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
108  GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
109  } else {
110  const std::array<int, 4> X_dims = {N, HxW, G, D};
111  const std::array<int, 4> Y_dims = {N, 1, G, 1};
112  math::Moments<T, Context>(
113  4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
114  math::InvStd<T, Context>(
115  N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
116  ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
117  GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
118  }
119  return true;
120  }
121 
122  void ComputeFusedParams(
123  const int N,
124  const int G,
125  const int D,
126  const T* mu,
127  const T* rsig,
128  const T* gamma,
129  const T* beta,
130  T* scale,
131  T* bias) {
132  const int C = G * D;
133  ConstEigenArrayMap<float> gamma_arr(gamma, D, G);
134  ConstEigenArrayMap<float> beta_arr(beta, D, G);
135  for (int i = 0; i < N; ++i) {
136  EigenArrayMap<T> scale_arr(scale + i * C, D, G);
137  scale_arr = gamma_arr.rowwise() *
138  ConstEigenVectorArrayMap<T>(rsig + i * G, G).transpose();
139  EigenArrayMap<T>(bias + i * C, D, G) = beta_arr -
140  scale_arr.rowwise() *
141  ConstEigenVectorArrayMap<T>(mu + i * G, G).transpose();
142  }
143  }
144 
145  void GroupNormForwardNCHW(
146  const int N,
147  const int C,
148  const int HxW,
149  const T* X,
150  const T* scale,
151  const T* bias,
152  T* Y) {
153  EigenArrayMap<float>(Y, HxW, N * C) =
154  (ConstEigenArrayMap<float>(X, HxW, N * C).rowwise() *
155  ConstEigenVectorArrayMap<float>(scale, N * C).transpose())
156  .rowwise() +
157  ConstEigenVectorArrayMap<float>(bias, N * C).transpose();
158  }
159 
160  void GroupNormForwardNHWC(
161  const int N,
162  const int C,
163  const int HxW,
164  const T* X,
165  const T* scale,
166  const T* bias,
167  T* Y) {
168  const int stride = HxW * C;
169  for (int i = 0; i < N; ++i) {
170  EigenArrayMap<float>(Y + i * stride, C, HxW) =
171  (ConstEigenArrayMap<float>(X + i * stride, C, HxW).colwise() *
172  ConstEigenVectorArrayMap<float>(scale + i * C, C))
173  .colwise() +
174  ConstEigenVectorArrayMap<float>(bias + i * C, C);
175  }
176  }
177 
178  const int group_;
179  const float epsilon_;
180  const StorageOrder order_;
181  const bool is_test_;
182 
183  Tensor mu_;
184  Tensor rsig_;
185  Tensor scale_;
186  Tensor bias_;
187 
188  // Input: X, gamma, beta
189  // Output: Y, mu, inv_sig
190  INPUT_TAGS(INPUT, GAMMA, BETA);
191  OUTPUT_TAGS(OUTPUT, MU, INV_SIGMA);
192 };
193 
194 template <typename T, class Context>
195 class GroupNormGradientOp final : public Operator<Context> {
196  public:
197  USE_OPERATOR_CONTEXT_FUNCTIONS;
198 
199  template <class... Args>
200  explicit GroupNormGradientOp(Args&&... args)
201  : Operator<Context>(std::forward<Args>(args)...),
202  OP_SINGLE_ARG(int, "group", group_, 32),
203  order_(StringToStorageOrder(
204  this->template GetSingleArgument<std::string>("order", "NCHW"))) {
205  CAFFE_ENFORCE_NE(
206  order_,
207  StorageOrder::UNKNOWN,
208  "order should be either \"NCHW\" or \"NHWC\".");
209  }
210 
211  bool RunOnDevice() override {
212  const auto& dY = Input(OUTPUT_GRAD);
213  const auto& X = Input(INPUT);
214  const auto& gamma = Input(GAMMA);
215  const auto& beta = Input(BETA);
216  const auto& mu = Input(MU);
217  const auto& rsig = Input(INV_SIGMA);
218  const int ndim = X.dim();
219  const int N = X.dim32(0);
220  const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
221  const int HxW = X.numel() / (N * C);
222  CAFFE_ENFORCE_EQ(C % group_, 0);
223  CAFFE_ENFORCE_EQ(gamma.numel(), C);
224  CAFFE_ENFORCE_EQ(beta.numel(), C);
225  const int G = group_;
226  const int D = C / G;
227 
228  auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
229  auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype<T>());
230  auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype<T>());
231  return RunOnDeviceImpl(
232  N,
233  G,
234  D,
235  HxW,
236  dY.template data<T>(),
237  X.template data<T>(),
238  mu.template data<T>(),
239  rsig.template data<T>(),
240  gamma.template data<T>(),
241  dX->template mutable_data<T>(),
242  dgamma->template mutable_data<T>(),
243  dbeta->template mutable_data<T>());
244  }
245 
246  protected:
247  bool RunOnDeviceImpl(
248  const int N,
249  const int G,
250  const int D,
251  const int HxW,
252  const T* dY_data,
253  const T* X_data,
254  const T* mu_data,
255  const T* rsig_data,
256  const T* gamma_data,
257  T* dX_data,
258  T* dgamma_data,
259  T* dbeta_data);
260 
261  const int group_;
262  const StorageOrder order_;
263 
264  Tensor ds_;
265  Tensor db_;
266 
267  // Input: dY, X, gamma, beta, mu, inv_sig
268  // Output: dX, dgamma, dbeta
269  INPUT_TAGS(OUTPUT_GRAD, INPUT, GAMMA, BETA, MU, INV_SIGMA);
270  OUTPUT_TAGS(INPUT_GRAD, GAMMA_GRAD, BETA_GRAD);
271 };
272 
273 } // namespace caffe2
274 
275 #endif // CAFFE2_OPERATORS_GROUP_NORM_OP_H_
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary; note that this will not do anything if the Tensor already has the correct size and data type.
Definition: tensor.cc:127
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator.
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime environment.
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:70