#ifndef CAFFE2_OPERATORS_GROUP_NORM_OP_H_
#define CAFFE2_OPERATORS_GROUP_NORM_OP_H_

#include <array>
#include <string>
#include <vector>

#include "caffe2/core/common.h"
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"

namespace caffe2 {

template <typename T, class Context>
class GroupNormOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
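
  // Group normalization splits the C channels into `group_` groups of size
  // D = C / G and normalizes each (n, g) slice of X to zero mean and unit
  // variance before applying the per-channel affine parameters gamma and beta.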
  template <class... Args>
  explicit GroupNormOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "group", group_, 32),
        OP_SINGLE_ARG(float, "epsilon", epsilon_, 1e-5),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<std::string>("order", "NCHW"))),
        OP_SINGLE_ARG(bool, OpSchema::Arg_IsTest, is_test_, true) {
    CAFFE_ENFORCE_NE(
        order_,
        StorageOrder::UNKNOWN,
        "order should be either \"NCHW\" or \"NHWC\".");
    if (!is_test_) {
      CAFFE_ENFORCE_EQ(OutputSize(), 3);
    }
  }
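
  // X is (N, C, H, W) in NCHW order or (N, H, W, C) in NHWC order; gamma and
  // beta each hold C elements.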
  bool RunOnDevice() override {
    const auto& X = Input(INPUT);
    const auto& gamma = Input(GAMMA);
    const auto& beta = Input(BETA);
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    const int HxW = X.numel() / (N * C);
    CAFFE_ENFORCE_EQ(C % group_, 0);
    CAFFE_ENFORCE_EQ(gamma.numel(), C);
    CAFFE_ENFORCE_EQ(beta.numel(), C);
    const int G = group_;
    const int D = C / G;
    auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
    T* mu_data = nullptr;
    T* rsig_data = nullptr;
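    // In training mode (OutputSize() == 3) the per-group mean and inverse
    // standard deviation are exposed as outputs for the gradient op;
    // otherwise they are kept in internal buffers.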
    if (OutputSize() == 3) {
      auto* mu = Output(MU, {N, G}, at::dtype<T>());
      auto* rsig = Output(INV_SIGMA, {N, G}, at::dtype<T>());
      mu_data = mu->template mutable_data<T>();
      rsig_data = rsig->template mutable_data<T>();
    } else {
      ReinitializeTensor(
          &mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
      ReinitializeTensor(
          &rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
      mu_data = mu_.template mutable_data<T>();
      rsig_data = rsig_.template mutable_data<T>();
    }
    return RunOnDeviceImpl(
        N,
        G,
        D,
        HxW,
        X.template data<T>(),
        gamma.template data<T>(),
        beta.template data<T>(),
        Y->template mutable_data<T>(),
        mu_data,
        rsig_data);
  }

 protected:
  bool RunOnDeviceImpl(
      const int N,
      const int G,
      const int D,
      const int HxW,
      const T* X,
      const T* gamma,
      const T* beta,
      T* Y,
      T* mu,
      T* rsig) {
    const int C = G * D;
    ReinitializeTensor(
        &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    ReinitializeTensor(
        &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    T* scale_data = scale_.template mutable_data<T>();
    T* bias_data = bias_.template mutable_data<T>();
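    // Compute per-(n, g) mean and variance. NCHW views X as an
    // (N * G) x (D * HxW) matrix reduced over its columns; NHWC views X as an
    // N x HxW x G x D tensor reduced over the HxW and D axes. InvStd then
    // turns the variance into 1 / sqrt(var + epsilon) in place.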
    if (order_ == StorageOrder::NCHW) {
      const std::array<int, 2> X_dims = {N * G, D * HxW};
      const std::array<int, 2> Y_dims = {N * G, 1};
      math::Moments<T, Context>(
          2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
      math::InvStd<T, Context>(
          N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
      ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
      GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
    } else {
      const std::array<int, 4> X_dims = {N, HxW, G, D};
      const std::array<int, 4> Y_dims = {N, 1, G, 1};
      math::Moments<T, Context>(
          4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
      math::InvStd<T, Context>(
          N * G, static_cast<T>(epsilon_), rsig, rsig, &context_);
      ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
      GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
    }
    return true;
  }
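
  // Folds normalization and the affine transform into one scale and bias per
  // (n, c): scale = gamma * rsig and bias = beta - scale * mu, so that the
  // forward pass reduces to Y = X * scale + bias.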
  void ComputeFusedParams(
      const int N,
      const int G,
      const int D,
      const T* mu,
      const T* rsig,
      const T* gamma,
      const T* beta,
      T* scale,
      T* bias) {
    const int C = G * D;
    ConstEigenArrayMap<float> gamma_arr(gamma, D, G);
    ConstEigenArrayMap<float> beta_arr(beta, D, G);
    for (int i = 0; i < N; ++i) {
      EigenArrayMap<T> scale_arr(scale + i * C, D, G);
      scale_arr = gamma_arr.rowwise() *
          ConstEigenVectorArrayMap<T>(rsig + i * G, G).transpose();
      EigenArrayMap<T>(bias + i * C, D, G) = beta_arr -
          scale_arr.rowwise() *
              ConstEigenVectorArrayMap<T>(mu + i * G, G).transpose();
    }
  }
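
  // NCHW: X and Y are viewed as HxW x (N * C) column-major matrices, so every
  // column is one (n, c) image plane scaled and shifted by its fused params.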
  void GroupNormForwardNCHW(
      const int N,
      const int C,
      const int HxW,
      const T* X,
      const T* scale,
      const T* bias,
      T* Y) {
    EigenArrayMap<float>(Y, HxW, N * C) =
        (ConstEigenArrayMap<float>(X, HxW, N * C).rowwise() *
         ConstEigenVectorArrayMap<float>(scale, N * C).transpose())
            .rowwise() +
        ConstEigenVectorArrayMap<float>(bias, N * C).transpose();
  }
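
  // NHWC: each image is viewed as a C x HxW column-major matrix, so the fused
  // per-channel scale and bias are applied column-wise.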
  void GroupNormForwardNHWC(
      const int N,
      const int C,
      const int HxW,
      const T* X,
      const T* scale,
      const T* bias,
      T* Y) {
    const int stride = HxW * C;
    for (int i = 0; i < N; ++i) {
      EigenArrayMap<float>(Y + i * stride, C, HxW) =
          (ConstEigenArrayMap<float>(X + i * stride, C, HxW).colwise() *
           ConstEigenVectorArrayMap<float>(scale + i * C, C))
              .colwise() +
          ConstEigenVectorArrayMap<float>(bias + i * C, C);
    }
  }

  const int group_;
  const float epsilon_;
  const StorageOrder order_;
  const bool is_test_;

  Tensor mu_;
  Tensor rsig_;
  Tensor scale_;
  Tensor bias_;
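
  // Input: X, gamma, beta
  // Output: Y, mu, inv_sig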
  INPUT_TAGS(INPUT, GAMMA, BETA);
  OUTPUT_TAGS(OUTPUT, MU, INV_SIGMA);
};

template <typename T, class Context>
class GroupNormGradientOp final : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  template <class... Args>
  explicit GroupNormGradientOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...),
        OP_SINGLE_ARG(int, "group", group_, 32),
        order_(StringToStorageOrder(
            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
    CAFFE_ENFORCE_NE(
        order_,
        StorageOrder::UNKNOWN,
        "order should be either \"NCHW\" or \"NHWC\".");
  }
  bool RunOnDevice() override {
    const auto& dY = Input(OUTPUT_GRAD);
    const auto& X = Input(INPUT);
    const auto& gamma = Input(GAMMA);
    const auto& beta = Input(BETA);
    const auto& mu = Input(MU);
    const auto& rsig = Input(INV_SIGMA);
    const int ndim = X.dim();
    const int N = X.dim32(0);
    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    const int HxW = X.numel() / (N * C);
    CAFFE_ENFORCE_EQ(C % group_, 0);
    CAFFE_ENFORCE_EQ(gamma.numel(), C);
    CAFFE_ENFORCE_EQ(beta.numel(), C);
    const int G = group_;
    const int D = C / G;
    auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
    auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype<T>());
    auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype<T>());
    return RunOnDeviceImpl(
        N,
        G,
        D,
        HxW,
        dY.template data<T>(),
        X.template data<T>(),
        mu.template data<T>(),
        rsig.template data<T>(),
        gamma.template data<T>(),
        dX->template mutable_data<T>(),
        dgamma->template mutable_data<T>(),
        dbeta->template mutable_data<T>());
  }

 protected:
  bool RunOnDeviceImpl(
      const int N,
      const int G,
      const int D,
      const int HxW,
      const T* dY,
      const T* X,
      const T* mu,
      const T* rsig,
      const T* gamma,
      T* dX,
      T* dgamma,
      T* dbeta);

  const int group_;
  const StorageOrder order_;
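
  // Input: dY, X, gamma, beta, mu, inv_sig
  // Output: dX, dgamma, dbeta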
  INPUT_TAGS(OUTPUT_GRAD, INPUT, GAMMA, BETA, MU, INV_SIGMA);
  OUTPUT_TAGS(INPUT_GRAD, GAMMA_GRAD, BETA_GRAD);
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_GROUP_NORM_OP_H_