#include "caffe2/operators/group_norm_op.h"

#include <array>

#include "caffe2/utils/math.h"

namespace caffe2 {

namespace {

template <typename T, StorageOrder kOrder>
void ComputeInternalGradients(
    const std::array<int, 4>& dims,
    const T* dY,
    const T* X,
    const T* gamma,
    T* ds,
    T* db) {
  constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
  constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
  const int size = dims[0] * dims[1] * dims[2] * dims[3];
  std::array<int, 4> index = {0, 0, 0, 0};
  for (int i = 0; i < size; ++i) {
    const int i_mu = index[0] * dims[kGDim] + index[kGDim];
    const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim];
    ds[i_mu] += gamma[i_gamma] * dY[i] * X[i];
    db[i_mu] += gamma[i_gamma] * dY[i];
    math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
  }
}
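// Math (derivation sketch for GroupNormBackward below):
//   Forward: Y = gamma * (X - mu) * rsig + beta, where mu and rsig are the
//   mean and reciprocal std over each group of n = D * HxW elements.
//   With ds = Sum(gamma * dY * X) and db = Sum(gamma * dY) over a group
//   (accumulated by ComputeInternalGradients above), the chain rule gives
//     dL/dX = gamma * dY * rsig
//             + ((db * mu - ds) * (X - mu) * rsig^3 - db * rsig) / n,
//   which corresponds to u, v and denom in the loop below. dL/dgamma and
//   dL/dbeta are reductions of dY * (X - mu) * rsig and dY per channel.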
template <typename T, StorageOrder kOrder>
void GroupNormBackward(
    const std::array<int, 4>& dims,
    const T* dY,
    const T* X,
    const T* mu,
    const T* rsig,
    const T* gamma,
    const T* ds,
    const T* db,
    T* dX,
    T* dgamma,
    T* dbeta) {
  constexpr int kGDim = kOrder == StorageOrder::NCHW ? 1 : 2;
  constexpr int kDDim = kOrder == StorageOrder::NCHW ? 2 : 3;
  const int size = dims[0] * dims[1] * dims[2] * dims[3];
  const int HxW = kOrder == StorageOrder::NCHW ? dims[3] : dims[1];
  const T denom = T(1) / static_cast<T>(dims[kDDim] * HxW);
  std::array<int, 4> index = {0, 0, 0, 0};
  for (int i = 0; i < size; ++i) {
    const int i_mu = index[0] * dims[kGDim] + index[kGDim];
    const int i_gamma = index[kGDim] * dims[kDDim] + index[kDDim];
    const T u = (db[i_mu] * mu[i_mu] - ds[i_mu]) * (X[i] - mu[i_mu]) *
        math::utils::Cube(rsig[i_mu]);
    const T v = db[i_mu] * rsig[i_mu];
    dX[i] = gamma[i_gamma] * dY[i] * rsig[i_mu] + (u - v) * denom;
    dgamma[i_gamma] += dY[i] * (X[i] - mu[i_mu]) * rsig[i_mu];
    dbeta[i_gamma] += dY[i];
    math::utils::IncreaseIndexInDims(4, dims.data(), index.data());
  }
}

} // namespace
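// Math:
//   let: s = gamma * rsig
//   let: b = beta - mu * gamma * rsig
//   then: Y = s * X + b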
template <typename T, class Context>
bool GroupNormGradientOp<T, Context>::RunOnDeviceImpl(
    const int N,
    const int G,
    const int D,
    const int HxW,
    const T* dY_data,
    const T* X_data,
    const T* mu_data,
    const T* rsig_data,
    const T* gamma_data,
    T* dX_data,
    T* dgamma_data,
    T* dbeta_data) {
  const std::array<int, 4> dims = order_ == StorageOrder::NCHW
      ? std::array<int, 4>{N, G, D, HxW}
      : std::array<int, 4>{N, HxW, G, D};
  const int C = G * D;
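  // Computes dL/ds and dL/db over each (n, g) pair:
  //   dL/ds = Sum(dL/dY * gamma * X)
  //   dL/db = Sum(dL/dY * gamma)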
  ReinitializeTensor(
      &ds_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
  ReinitializeTensor(
      &db_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
  T* ds_data = ds_.template mutable_data<T>();
  T* db_data = db_.template mutable_data<T>();
  math::Set<T, Context>(N * G, T(0), ds_data, &context_);
  math::Set<T, Context>(N * G, T(0), db_data, &context_);
  if (order_ == StorageOrder::NCHW) {
    ComputeInternalGradients<T, StorageOrder::NCHW>(
        dims, dY_data, X_data, gamma_data, ds_data, db_data);
  } else {
    ComputeInternalGradients<T, StorageOrder::NHWC>(
        dims, dY_data, X_data, gamma_data, ds_data, db_data);
  }
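  // Computes dL/dX, dL/dgamma and dL/dbeta from ds and db.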
  math::Set<T, Context>(C, T(0), dgamma_data, &context_);
  math::Set<T, Context>(C, T(0), dbeta_data, &context_);
  if (order_ == StorageOrder::NCHW) {
    GroupNormBackward<T, StorageOrder::NCHW>(
        dims, dY_data, X_data, mu_data, rsig_data, gamma_data,
        ds_data, db_data, dX_data, dgamma_data, dbeta_data);
  } else {
    GroupNormBackward<T, StorageOrder::NHWC>(
        dims, dY_data, X_data, mu_data, rsig_data, gamma_data,
        ds_data, db_data, dX_data, dgamma_data, dbeta_data);
  }
  return true;
}
REGISTER_CPU_OPERATOR(GroupNorm, GroupNormOp<float, CPUContext>);
REGISTER_CPU_OPERATOR(
    GroupNormGradient,
    GroupNormGradientOp<float, CPUContext>);
// Input: X, gamma, beta; Output: Y, mu, sig
OPERATOR_SCHEMA(GroupNorm)
    .NumInputs(3)
    .NumOutputs({1, 3})
    .SetDoc(R"DOC(
Group Normalization (GN) operation: https://arxiv.org/abs/1803.08494
)DOC")
    .Arg("num_groups", "(int) default 32; number of groups used by GN.")
    .Arg("epsilon", "(float) default 1e-5; small constant added to var.")
    .Input(
        0,
        "X",
        ">=4D feature map input of shape (N, C, H, W) or (N, C, T, H, W)")
    .Input(
        1,
        "gamma",
        "The scale as a 1-dimensional tensor of size C to be applied to the "
        "output.")
    .Input(
        2,
        "beta",
        "The bias as a 1-dimensional tensor of size C to be applied to the "
        "output.")
    .Output(0, "Y", "The output >=4-dimensional tensor of the same shape as X.")
    .Output(
        1,
        "mean",
        "The mean of shape (N, G). For backward usage or reference. "
        "Cannot be used as activations.")
    .Output(
        2,
        "std",
        "The std of shape (N, G). For backward usage or reference. "
        "Cannot be used as activations.");
OPERATOR_SCHEMA(GroupNormGradient).NumInputs(6).NumOutputs(3);
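// Input: dY, X, gamma, beta, mu, sig; Output: dX, dgamma, dbeta
// (see GetGroupNormGradient below for the exact blob wiring).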
namespace {

class GetGroupNormGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    return SingleGradientDef(
        "GroupNormGradient",
        "",
        vector<string>{GO(0), I(0), I(1), I(2), O(1), O(2)},
        vector<string>{GI(0), GI(1), GI(2)});
  }
};

} // namespace

REGISTER_GRADIENT(GroupNorm, GetGroupNormGradient);

} // namespace caffe2
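// A minimal usage sketch (kept as a comment; assumes the standard Caffe2 C++
// API, and all blob names/shapes here are illustrative only):
//
//   #include "caffe2/core/operator.h"
//   #include "caffe2/core/workspace.h"
//   #include "caffe2/utils/proto_utils.h"
//
//   caffe2::Workspace ws;
//   // ... create and fill blobs "X" (N, C, H, W), "gamma" (C), "beta" (C) ...
//   caffe2::OperatorDef def;
//   def.set_type("GroupNorm");
//   for (const auto* name : {"X", "gamma", "beta"}) def.add_input(name);
//   for (const auto* name : {"Y", "mu", "sig"}) def.add_output(name);
//   def.add_arg()->CopyFrom(caffe2::MakeArgument<int>("num_groups", 2));
//   auto op = caffe2::CreateOperator(def, &ws);
//   op->Run();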