Caffe2 - C++ API
A deep learning, cross platform ML framework
group_norm_op.cc
// ------------------------------------------------------------------
// GroupNorm op in Caffe2 for CPU
// Written by Kaiming He
// Improved by Xiaomeng Yang
// see https://arxiv.org/abs/1803.08494
// This is a stand-alone op: Y = gamma * (X - mu) / sig + beta
// ------------------------------------------------------------------

#include "caffe2/operators/group_norm_op.h"

namespace caffe2 {

namespace {

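// ComputeInternalGradients accumulates, for every (n, g) group, the two
// reductions used by GroupNormBackward below:
//   ds = Sum(gamma * dY * X) and db = Sum(gamma * dY),
// where the sum runs over the D * HxW elements of the group.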
template <typename T, StorageOrder kOrder>
void ComputeInternalGradients(
    const std::array<int, 4>& dims,
    const T* dY,
    const T* X,
    const T* gamma,
    T* ds,
    T* db) {
  constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
  constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
  const int size = dims[0] * dims[1] * dims[2] * dims[3];
  std::array<int, 4> index = {0, 0, 0, 0};
  for (int i = 0; i < size; ++i) {
    const int i_mu = index[0] * dims[kGDim] + index[kGDim];
    const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim];
    ds[i_mu] += gamma[i_gamma] * dY[i] * X[i];
    db[i_mu] += gamma[i_gamma] * dY[i];
    math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
  }
}

// Math:
// Y = gamma * (X - mu) * rsig + beta
// let s = gamma * rsig
// let b = beta - gamma * mu * rsig
// Y = s * X + b
// let n = D * HxW
// dL/dX = dL/dY * dY/dX = dL/dY * (d(s * X)/dX + db/dX)
// d(s * X)/dX = s + X * ds/dX = s + gamma * X * drsig/dX
// db/dX = -gamma * mu * drsig/dX - gamma * rsig * dmu/dX
// drsig/dX = -rsig^3 * (X - mu) / n
// dmu/dX = 1 / n
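// Substituting the two derivatives above and collecting terms gives the
// closed form implemented below, with ds and db the per-group reductions
// from ComputeInternalGradients:
//   dL/dX = gamma * dY * rsig
//           + ((db * mu - ds) * (X - mu) * rsig^3 - db * rsig) / n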
template <typename T, StorageOrder kOrder>
void GroupNormBackward(
    const std::array<int, 4>& dims,
    const T* dY,
    const T* X,
    const T* mu,
    const T* rsig,
    const T* gamma,
    const T* ds,
    const T* db,
    T* dX,
    T* dgamma,
    T* dbeta) {
  constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
  constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
  const int size = dims[0] * dims[1] * dims[2] * dims[3];
  const int HxW = kOrder == StorageOrder::NCHW ? dims[3] : dims[1];
  const T denom = T(1) / static_cast<T>(dims[kDDim] * HxW);
  std::array<int, 4> index = {0, 0, 0, 0};
  for (int i = 0; i < size; ++i) {
    const int i_mu = index[0] * dims[kGDim] + index[kGDim];
    const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim];
    const T u = (db[i_mu] * mu[i_mu] - ds[i_mu]) * (X[i] - mu[i_mu]) *
        math::utils::Cube(rsig[i_mu]);
    const T v = db[i_mu] * rsig[i_mu];
    dX[i] = gamma[i_gamma] * dY[i] * rsig[i_mu] + (u - v) * denom;
    dgamma[i_gamma] += dY[i] * (X[i] - mu[i_mu]) * rsig[i_mu];
    dbeta[i_gamma] += dY[i];
    math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
  }
}

} // namespace

// Math:
// let: s = gamma * rsig
// let: b = beta - mu * gamma * rsig
// then: Y = s * X + b
template <typename T, class Context>
bool GroupNormGradientOp<T, Context>::RunOnDeviceImpl(
    const int N,
    const int G,
    const int D,
    const int HxW,
    const T* dY_data,
    const T* X_data,
    const T* mu_data,
    const T* rsig_data,
    const T* gamma_data,
    T* dX_data,
    T* dgamma_data,
    T* dbeta_data) {
  const std::array<int, 4> dims = order_ == StorageOrder::NCHW
      ? std::array<int, 4>{N, G, D, HxW}
      : std::array<int, 4>{N, HxW, G, D};

  // Computes dL/ds and dL/db.
  // dL/ds = Sum(dL/dY * gamma * X)
  // dL/db = Sum(dL/dY * gamma)
  const int C = G * D;
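  // Scratch tensors ds_ and db_ hold these per-(N, G) reductions;
  // ReinitializeTensor (re)allocates them only when the current shape or
  // options do not match.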
  ReinitializeTensor(
      &ds_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
  ReinitializeTensor(
      &db_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
  T* ds_data = ds_.template mutable_data<T>();
  T* db_data = db_.template mutable_data<T>();
  math::Set<T, Context>(N * G, T(0), ds_data, &context_);
  math::Set<T, Context>(N * G, T(0), db_data, &context_);
  if (order_ == StorageOrder::NCHW) {
    ComputeInternalGradients<T, StorageOrder::NCHW>(
        dims, dY_data, X_data, gamma_data, ds_data, db_data);
  } else {
    ComputeInternalGradients<T, StorageOrder::NHWC>(
        dims, dY_data, X_data, gamma_data, ds_data, db_data);
  }

  // Computes dL/dX, dL/dgamma and dL/dbeta.
  math::Set<T, Context>(C, T(0), dgamma_data, &context_);
  math::Set<T, Context>(C, T(0), dbeta_data, &context_);
  if (order_ == StorageOrder::NCHW) {
    GroupNormBackward<T, StorageOrder::NCHW>(
        dims,
        dY_data,
        X_data,
        mu_data,
        rsig_data,
        gamma_data,
        ds_data,
        db_data,
        dX_data,
        dgamma_data,
        dbeta_data);
  } else {
    GroupNormBackward<T, StorageOrder::NHWC>(
        dims,
        dY_data,
        X_data,
        mu_data,
        rsig_data,
        gamma_data,
        ds_data,
        db_data,
        dX_data,
        dgamma_data,
        dbeta_data);
  }
  return true;
}

REGISTER_CPU_OPERATOR(GroupNorm, GroupNormOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    GroupNormGradient,
    GroupNormGradientOp<float, CPUContext>);

// Warning: mu and rsig are for backward usage or reference. They should NOT be
// used as forward activations as they have no direct gradients computed.

// Input: X, gamma, beta; Output: Y, mu, sig
OPERATOR_SCHEMA(GroupNorm)
    .NumInputs(3)
    .NumOutputs({1, 3})
    .SetDoc(R"DOC(
Group Normalization (GN) operation: https://arxiv.org/abs/1803.08494
)DOC")
    .Arg("num_groups", "(int) default 32; number of groups used by GN.")
    .Arg("epsilon", "(float) default 1e-5; small constant added to var.")
    .Input(
        0,
        "X",
        ">=4D feature map input of shape (N, C, H, W) or (N, C, T, H, W)")
    .Input(
        1,
        "gamma",
        "The scale as a 1-dimensional tensor of size C to be applied to the "
        "output.")
    .Input(
        2,
        "beta",
        "The bias as a 1-dimensional tensor of size C to be applied to the "
        "output.")
    .Output(0, "Y", "The output >=4-dimensional tensor of the same shape as X.")
    .Output(
        1,
        "mean",
        "The mean of shape (N, G). "
        "For backward usage or reference. "
        "Cannot be used as activations.")
    .Output(
        2,
        "std",
        "The std of shape (N, G). "
        "For backward usage or reference. "
        "Cannot be used as activations.");

// Input: dY, X, gamma, beta, mu, sig; Output: dX, dgamma, dbeta
OPERATOR_SCHEMA(GroupNormGradient).NumInputs(6).NumOutputs(3);

class GetGroupNormGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "GroupNormGradient",
        "",
        vector<string>{GO(0), I(0), I(1), I(2), O(1), O(2)},
        vector<string>{GI(0), GI(1), GI(2)});
  }
};

REGISTER_GRADIENT(GroupNorm, GetGroupNormGradient);

} // namespace caffe2
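
As a rough illustration of how the registered CPU operator can be run outside of a full network, the sketch below builds an OperatorDef by hand and executes it through a Workspace. It is not part of group_norm_op.cc; it assumes the blobs "X", "gamma" and "beta" have already been fed into the workspace with compatible shapes, and the blob names and the function name RunGroupNormOnce are made up for this example.

#include "caffe2/core/operator.h"
#include "caffe2/core/workspace.h"

// Runs a single GroupNorm forward pass on blobs already present in `ws`.
void RunGroupNormOnce(caffe2::Workspace* ws) {
  caffe2::OperatorDef def;
  def.set_type("GroupNorm");
  def.add_input("X");      // (N, C, H, W) feature map, fed beforehand
  def.add_input("gamma");  // scale, 1-D tensor of size C
  def.add_input("beta");   // bias, 1-D tensor of size C
  def.add_output("Y");
  def.add_output("mu");    // per-(N, G) mean, backward/reference only
  def.add_output("rsig");  // per-(N, G) std, backward/reference only
  auto* groups = def.add_arg();
  groups->set_name("num_groups");
  groups->set_i(32);
  auto* eps = def.add_arg();
  eps->set_name("epsilon");
  eps->set_f(1e-5f);
  CAFFE_ENFORCE(ws->RunOperatorOnce(def));
}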