// NOTE(review): this text is a rendered source listing, not compilable C++.
// Each fragment is prefixed with its original line number. Those numbers skip
// (e.g. 2 -> 8, 12 -> 16, 28 -> 31), so several original lines are absent.
// The comments below describe only what the surviving lines show.
1 #ifndef CAFFE2_OPERATORS_GROUP_NORM_OP_H_     2 #define CAFFE2_OPERATORS_GROUP_NORM_OP_H_     8 #include "caffe2/core/common.h"     9 #include "caffe2/core/context.h"    10 #include "caffe2/core/operator.h"    11 #include "caffe2/utils/eigen_utils.h"    12 #include "caffe2/utils/math.h"    16 template <
// GroupNormOp<T, Context>: forward group-normalization operator.
// The class head (original lines 17-18) is missing from this listing.
typename T, 
class Context>
    19   USE_OPERATOR_CONTEXT_FUNCTIONS;
// Variadic constructor; its signature (original lines 22-23) is missing.
// It reads these operator arguments (apparent defaults in parentheses):
//   "group"              -> group_   (int, 32)
//   "epsilon"            -> epsilon_ (float, 1e-5)
//   "order"              -> order_   (string "NCHW", parsed by StringToStorageOrder)
//   OpSchema::Arg_IsTest -> is_test_ (bool, true)
    21   template <
class... Args>
    24         OP_SINGLE_ARG(
int, 
"group", group_, 32),
    25         OP_SINGLE_ARG(
float, 
"epsilon", epsilon_, 1e-5),
    26         order_(StringToStorageOrder(
    27             this->
template GetSingleArgument<std::string>(
"order", 
"NCHW"))),
    28         OP_SINGLE_ARG(
bool, OpSchema::Arg_IsTest, is_test_, 
true) {
// Rejects an unrecognized "order" string. The surviving lines compare against
// StorageOrder::UNKNOWN; the enforce macro itself is on missing lines 29-30.
    31         StorageOrder::UNKNOWN,
    32         "order should be either \"NCHW\" or \"NHWC\".");
// Requires exactly three outputs (OUTPUT, MU, INV_SIGMA per OUTPUT_TAGS).
// NOTE(review): original line 33 is missing, so it is unclear whether this
// check is guarded (e.g. by is_test_). If it is unconditional, the else-branch
// of RunOnDevice's `OutputSize() == 3` test can never run. Confirm against the
// full source.
    34       CAFFE_ENFORCE_EQ(OutputSize(), 3);
    // RunOnDevice (forward) does three things in order:
    //   1. validates X, gamma and beta and allocates Y with X's shape;
    //   2. chooses where the per-group statistics (mu, rsig) are written;
    //   3. delegates the computation to RunOnDeviceImpl.
    38   bool RunOnDevice()
 override {
    39     const auto& X = 
Input(INPUT);
    40     const auto& gamma = 
Input(GAMMA);
    41     const auto& beta = 
Input(BETA);
    // N = X.dim32(0). C is dim 1 for NCHW, otherwise the last dim (NHWC).
    // HxW = the number of elements per (n, c) pair.
    // NOTE(review): nothing visible guards N * C == 0 (empty batch or zero
    // channels) before the division. Also, C % group_ below is undefined if
    // group_ is 0, and no positivity check on group_ appears in this listing.
    42     const int ndim = X.dim();
    43     const int N = X.dim32(0);
    44     const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    45     const int HxW = X.numel() / (N * C);
    // Channels must split evenly into group_ groups; gamma and beta must hold
    // one value per channel.
    46     CAFFE_ENFORCE_EQ(C % group_, 0);
    47     CAFFE_ENFORCE_EQ(gamma.numel(), C);
    48     CAFFE_ENFORCE_EQ(beta.numel(), C);
    // Y takes X's shape. Original lines 49-51 and 53 are missing; they
    // presumably declare G and mu_data, which are used below. Confirm.
    52     auto* Y = Output(OUTPUT, X.sizes(), at::dtype<T>());
    54     T* rsig_data = 
nullptr;
    // With three outputs, the statistics go straight into the MU and INV_SIGMA
    // outputs, each of shape {N, G}. Otherwise they go to the scratch members
    // mu_ / rsig_, sized {N, G} on Context's device. The `} else {` and the
    // call names on original lines 60-61 and 63 are missing; the calls are
    // presumably ReinitializeTensor.
    55     if (OutputSize() == 3) {
    56       auto* mu = Output(MU, {N, G}, at::dtype<T>());
    57       auto* rsig = Output(INV_SIGMA, {N, G}, at::dtype<T>());
    58       mu_data = mu->template mutable_data<T>();
    59       rsig_data = rsig->template mutable_data<T>();
    62           &mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
    64           &rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
    65       mu_data = mu_.template mutable_data<T>();
    66       rsig_data = rsig_.template mutable_data<T>();
    // Only the gamma, beta and Y arguments of this call survive in the
    // listing; original lines 69-73 and 77 onward are missing.
    68     return RunOnDeviceImpl(
    74         gamma.template data<T>(),
    75         beta.template data<T>(),
    76         Y->template mutable_data<T>(),
    // Fragment of GroupNormOp::RunOnDeviceImpl. Its signature and the opening
    // of the two reinitialization calls (original lines 78-94 and 96) are
    // missing. The scratch tensors scale_ and bias_ are sized {N, C} on
    // Context's device. They receive the fused per-(n, c) coefficients from
    // ComputeFusedParams.
    95         &scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    97         &bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
    98     T* scale_data = scale_.template mutable_data<T>();
    99     T* bias_data = bias_.template mutable_data<T>();
    // NCHW: X is viewed as an [N*G, D*HxW] array and reduced along its second
    // axis, giving one statistic per (n, g). D is presumably C / G, declared
    // on a missing line.
   100     if (order_ == StorageOrder::NCHW) {
   101       const std::array<int, 2> X_dims = {N * G, D * HxW};
   102       const std::array<int, 2> Y_dims = {N * G, 1};
    // math::Moments writes into mu and rsig; rsig presumably receives the
    // variance. InvStd then overwrites rsig in place (input == output),
    // presumably with 1 / sqrt(variance + epsilon_). Confirm in math utils.
   103       math::Moments<T, Context>(
   104           2, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
   105       math::InvStd<T, Context>(
   106           N * G, 
static_cast<T>(epsilon_), rsig, rsig, &context_);
    // Fold gamma/beta/mu/rsig into per-channel scale/bias, then apply them to X.
   107       ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
   108       GroupNormForwardNCHW(N, C, HxW, X, scale_data, bias_data, Y);
    // NHWC branch; its `} else {` line (original 109) is missing. X is viewed
    // as [N, HxW, G, D] and reduced over the HxW and D axes.
   110       const std::array<int, 4> X_dims = {N, HxW, G, D};
   111       const std::array<int, 4> Y_dims = {N, 1, G, 1};
   112       math::Moments<T, Context>(
   113           4, X_dims.data(), Y_dims.data(), X, mu, rsig, &context_);
   114       math::InvStd<T, Context>(
   115           N * G, 
static_cast<T>(epsilon_), rsig, rsig, &context_);
   116       ComputeFusedParams(N, G, D, mu, rsig, gamma, beta, scale_data, bias_data);
   117       GroupNormForwardNHWC(N, C, HxW, X, scale_data, bias_data, Y);
   // ComputeFusedParams: for each sample i, views its length-C slice of scale
   // and bias as a D x G array indexed (d, g), and computes:
   //   scale(d, g) = gamma(d, g) * rsig[i*G + g]
   //   bias(d, g)  = beta(d, g) - scale(d, g) * mu[i*G + g]
   // The forward kernels then only need Y = X * scale + bias per (n, c).
   // The parameter list (original lines 123-132) is missing from this
   // listing. C is presumably G * D, declared on one of those lines.
   122   void ComputeFusedParams(
   // NOTE(review): gamma/beta are mapped as float, while scale, rsig and mu use
   // T. If gamma/beta are `const T*` (signature not shown), this only compiles
   // for T = float, and the maps likely should use T. Confirm.
   133     ConstEigenArrayMap<float> gamma_arr(gamma, D, G);
   134     ConstEigenArrayMap<float> beta_arr(beta, D, G);
   135     for (
int i = 0; i < N; ++i) {
   136       EigenArrayMap<T> scale_arr(scale + i * C, D, G);
   // rowwise() * (1 x G row vector): column g is multiplied by rsig[i*G + g].
   137       scale_arr = gamma_arr.rowwise() *
   138           ConstEigenVectorArrayMap<T>(rsig + i * G, G).transpose();
   139       EigenArrayMap<T>(bias + i * C, D, G) = beta_arr -
   140           scale_arr.rowwise() *
   141               ConstEigenVectorArrayMap<T>(mu + i * G, G).transpose();
   // GroupNormForwardNCHW: views X and Y as HxW x (N*C) arrays and computes
   // Y(r, j) = X(r, j) * scale[j] + bias[j] for every column j in [0, N*C).
   // The parameter list (original lines 146-152) and line 156 (presumably
   // the `+`) are missing from this listing.
   // NOTE(review): every map here is hard-coded to float rather than T, so
   // this only matches the class template when T = float.
   145   void GroupNormForwardNCHW(
   153     EigenArrayMap<float>(Y, HxW, N * C) =
   154         (ConstEigenArrayMap<float>(X, HxW, N * C).rowwise() *
   155          ConstEigenVectorArrayMap<float>(scale, N * C).transpose())
   157         ConstEigenVectorArrayMap<float>(bias, N * C).transpose();
   // GroupNormForwardNHWC: for each sample i, views its HxW*C elements
   // (starting at i * stride) as a C x HxW array and computes
   //   Y(c, p) = X(c, p) * scale[i*C + c] + bias[i*C + c].
   // Missing from this listing: the parameter list (original lines 161-167),
   // line 173 (presumably the `+`), and the loop's closing brace.
   // NOTE(review): every map here is hard-coded to float rather than T.
   160   void GroupNormForwardNHWC(
   168     const int stride = HxW * C;
   169     for (
int i = 0; i < N; ++i) {
   170       EigenArrayMap<float>(Y + i * stride, C, HxW) =
   171           (ConstEigenArrayMap<float>(X + i * stride, C, HxW).colwise() *
   172            ConstEigenVectorArrayMap<float>(scale + i * C, C))
   174           ConstEigenVectorArrayMap<float>(bias + i * C, C);
   // GroupNormOp data members (partial). Original lines 175-178 and 181-189
   // are missing. They presumably hold group_, is_test_ and the scratch
   // tensors mu_, rsig_, scale_ and bias_ used by RunOnDevice and
   // RunOnDeviceImpl. Confirm.
   179   const float epsilon_;
   180   const StorageOrder order_;
   // Input order: X, gamma, beta. Output order: Y, mu, inv_sigma.
   190   INPUT_TAGS(INPUT, GAMMA, BETA);
   191   OUTPUT_TAGS(OUTPUT, MU, INV_SIGMA);
   // GroupNormGradientOp<T, Context>: backward pass of group normalization.
   // The class head (original lines 195-196) is missing from this listing.
   194 template <
typename T, 
class Context>
   197   USE_OPERATOR_CONTEXT_FUNCTIONS;
   // Variadic constructor; its signature (original lines 200-201) is missing.
   // It reads these operator arguments (apparent defaults in parentheses):
   //   "group" -> group_ (int, 32)
   //   "order" -> order_ (string "NCHW")
   // Unlike the forward op, it reads no "epsilon". Instead it takes
   // INV_SIGMA (a forward-op output) as an input.
   199   template <
class... Args>
   202         OP_SINGLE_ARG(
int, 
"group", group_, 32),
   203         order_(StringToStorageOrder(
   204             this->
template GetSingleArgument<std::string>(
"order", 
"NCHW"))) {
   // Rejects an unrecognized "order" string; the enforce macro is on missing
   // lines 205-206.
   207         StorageOrder::UNKNOWN,
   208         "order should be either \"NCHW\" or \"NHWC\".");
   // RunOnDevice (gradient) does three things in order:
   //   1. reads dY, X, gamma, beta, mu and rsig and repeats the forward op's
   //      shape checks;
   //   2. allocates dX, dgamma and dbeta with the shapes of X, gamma and beta;
   //   3. delegates to RunOnDeviceImpl.
   211   bool RunOnDevice()
 override {
   212     const auto& dY = 
Input(OUTPUT_GRAD);
   213     const auto& X = 
Input(INPUT);
   214     const auto& gamma = 
Input(GAMMA);
   215     const auto& beta = 
Input(BETA);
   216     const auto& mu = 
Input(MU);
   217     const auto& rsig = 
Input(INV_SIGMA);
   // N, C and HxW are derived exactly as in the forward op.
   // NOTE(review): nothing visible guards N * C == 0 before the division.
   // There is also no visible check that mu / rsig hold N * G elements or
   // that dY matches X. Original lines 226-227 are missing, so such a check
   // may exist there.
   218     const int ndim = X.dim();
   219     const int N = X.dim32(0);
   220     const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
   221     const int HxW = X.numel() / (N * C);
   // C must be divisible by group_; gamma and beta must have C elements.
   222     CAFFE_ENFORCE_EQ(C % group_, 0);
   223     CAFFE_ENFORCE_EQ(gamma.numel(), C);
   224     CAFFE_ENFORCE_EQ(beta.numel(), C);
   225     const int G = group_;
   // Each gradient takes the shape of the tensor it differentiates.
   228     auto* dX = Output(INPUT_GRAD, X.sizes(), at::dtype<T>());
   229     auto* dgamma = Output(GAMMA_GRAD, gamma.sizes(), at::dtype<T>());
   230     auto* dbeta = Output(BETA_GRAD, beta.sizes(), at::dtype<T>());
   // Only the pointer arguments of this call survive; the leading arguments
   // (original lines 232-235) are missing.
   231     return RunOnDeviceImpl(
   236         dY.template data<T>(),
   237         X.template data<T>(),
   238         mu.template data<T>(),
   239         rsig.template data<T>(),
   240         gamma.template data<T>(),
   241         dX->template mutable_data<T>(),
   242         dgamma->template mutable_data<T>(),
   243         dbeta->template mutable_data<T>());
   // Backward computation called by RunOnDevice. Only the first line of its
   // declaration (original 247) survives; the parameters (248-261) are missing.
   247   bool RunOnDeviceImpl(
   // Data members (partial). group_ is used by this class but declared on a
   // line missing from this listing.
   262   const StorageOrder order_;
   // Input order: dY, X, gamma, beta, mu, inv_sigma.
   // Output order: dX, dgamma, dbeta.
   269   INPUT_TAGS(OUTPUT_GRAD, INPUT, GAMMA, BETA, MU, INV_SIGMA);
   270   OUTPUT_TAGS(INPUT_GRAD, GAMMA_GRAD, BETA_GRAD);
   275 #endif // CAFFE2_OPERATORS_GROUP_NORM_OP_H_ void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary. Note that this will not do anything if ...
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...