1 #ifndef CAFFE2_OPERATORS_BATCH_MOMENTS_OP_H_     2 #define CAFFE2_OPERATORS_BATCH_MOMENTS_OP_H_     4 #include "caffe2/core/context.h"     5 #include "caffe2/core/logging.h"     6 #include "caffe2/core/operator.h"    10 template <
typename T, 
class Context>
    13   USE_OPERATOR_CONTEXT_FUNCTIONS;
    15   template <
class... Args>
    18         order_(StringToStorageOrder(
    19             this->
template GetSingleArgument<std::string>(
"order", 
"NCHW"))) {
    20     CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
    23   bool RunOnDevice()
 override {
    24     const auto& X = 
Input(0);
    26     const int ndim = X.dim();
    27     const int N = X.dim32(0);
    28     const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    29     const int HxW = X.numel() / (N * C);
    30     auto* mu = Output(0, {C}, at::dtype<T>());
    31     auto* var = Output(1, {C}, at::dtype<T>());
    32     const T* X_data = X.template data<T>();
    33     T* mu_data = mu->template mutable_data<T>();
    34     T* var_data = var->template mutable_data<T>();
    35     return order_ == StorageOrder::NCHW
    36         ? ComputeBatchMomentsNCHW(N, C, HxW, X_data, mu_data, var_data)
    37         : ComputeBatchMomentsNHWC(N, C, HxW, X_data, mu_data, var_data);
    41   bool ComputeBatchMomentsNCHW(
    49   bool ComputeBatchMomentsNHWC(
    57   const StorageOrder order_;
    60 template <
typename T, 
class Context>
    63   USE_OPERATOR_CONTEXT_FUNCTIONS;
    65   template <
class... Args>
    68         order_(StringToStorageOrder(
    69             this->
template GetSingleArgument<std::string>(
"order", 
"NCHW"))) {
    70     CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
    73   bool RunOnDevice()
 override {
    74     const auto& dmu = 
Input(0);
    75     const auto& dvar = 
Input(1);
    76     const auto& X = 
Input(2);
    78     const int ndim = X.dim();
    79     const int N = X.dim32(0);
    80     const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
    81     const int HxW = X.numel() / (N * C);
    82     auto* dX = Output(0, X.sizes(), at::dtype<T>());
    83     const T* dmu_data = dmu.template data<T>();
    84     const T* dvar_data = dvar.template data<T>();
    85     const T* X_data = X.template data<T>();
    86     T* dX_data = dX->template mutable_data<T>();
    87     return order_ == StorageOrder::NCHW
    88         ? ComputeBatchMomentsGradientNCHW(
    89               N, C, HxW, dmu_data, dvar_data, X_data, dX_data)
    90         : ComputeBatchMomentsGradientNHWC(
    91               N, C, HxW, dmu_data, dvar_data, X_data, dX_data);
    95   bool ComputeBatchMomentsGradientNCHW(
   104   bool ComputeBatchMomentsGradientNHWC(
   113   const StorageOrder order_;
   118 #endif // CAFFE2_OPERATORS_BATCH_MOMENTS_OP_H_ 
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
 
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...