// Declarations for the BatchMoments operator and its gradient operator.
// NOTE(review): this text looks like a code-browser extraction of the header.
// Upstream line numbers prefix many lines, and some upstream lines are missing
// (e.g. 7-9, 11-12 and 16-17 judging by the numbering jumps). Confirm against
// the upstream file before relying on this copy.
1 #ifndef CAFFE2_OPERATORS_BATCH_MOMENTS_OP_H_ 2 #define CAFFE2_OPERATORS_BATCH_MOMENTS_OP_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 10 template <
typename T,
class Context>
// NOTE(review): the class declaration line(s) (upstream 11-12) appear to be
// missing here.
13 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor (signature lines appear to be missing). It reads the "order"
// argument, defaulting to "NCHW", converts it with StringToStorageOrder, and
// rejects any value that maps to StorageOrder::UNKNOWN.
15 template <
class... Args>
18 order_(StringToStorageOrder(
19 this->
template GetSingleArgument<std::string>(
"order",
"NCHW"))) {
20 CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
23 bool RunOnDevice()
override {
24 const auto& X =
Input(0);
26 const int ndim = X.dim();
27 const int N = X.dim32(0);
28 const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
29 const int HxW = X.numel() / (N * C);
30 auto* mu = Output(0, {C}, at::dtype<T>());
31 auto* var = Output(1, {C}, at::dtype<T>());
32 const T* X_data = X.template data<T>();
33 T* mu_data = mu->template mutable_data<T>();
34 T* var_data = var->template mutable_data<T>();
35 return order_ == StorageOrder::NCHW
36 ? ComputeBatchMomentsNCHW(N, C, HxW, X_data, mu_data, var_data)
37 : ComputeBatchMomentsNHWC(N, C, HxW, X_data, mu_data, var_data);
// Layout-specific kernels for the forward op. Their parameter lists are
// missing from this extraction. The RunOnDevice call sites pass
// (N, C, HxW, X_data, mu_data, var_data) and return the bool result directly.
41 bool ComputeBatchMomentsNCHW(
49 bool ComputeBatchMomentsNHWC(
// Memory layout of X, initialized from the "order" argument in the
// constructor. It selects which kernel above RunOnDevice calls.
57 const StorageOrder order_;
// Gradient operator fragment, presumably BatchMomentsGradientOp.
// NOTE(review): the class declaration lines (upstream 61-62) and the
// constructor signature lines (upstream 66-67) appear to be missing from
// this extraction. Confirm against upstream.
60 template <
typename T,
class Context>
63 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Constructor initializer. It reads the "order" argument, defaulting to
// "NCHW", converts it with StringToStorageOrder, and rejects any value that
// maps to StorageOrder::UNKNOWN.
65 template <
class... Args>
68 order_(StringToStorageOrder(
69 this->
template GetSingleArgument<std::string>(
"order",
"NCHW"))) {
70 CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
73 bool RunOnDevice()
override {
74 const auto& dmu =
Input(0);
75 const auto& dvar =
Input(1);
76 const auto& X =
Input(2);
78 const int ndim = X.dim();
79 const int N = X.dim32(0);
80 const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
81 const int HxW = X.numel() / (N * C);
82 auto* dX = Output(0, X.sizes(), at::dtype<T>());
83 const T* dmu_data = dmu.template data<T>();
84 const T* dvar_data = dvar.template data<T>();
85 const T* X_data = X.template data<T>();
86 T* dX_data = dX->template mutable_data<T>();
87 return order_ == StorageOrder::NCHW
88 ? ComputeBatchMomentsGradientNCHW(
89 N, C, HxW, dmu_data, dvar_data, X_data, dX_data)
90 : ComputeBatchMomentsGradientNHWC(
91 N, C, HxW, dmu_data, dvar_data, X_data, dX_data);
// Layout-specific kernels for the gradient op. Their parameter lists are
// missing from this extraction. The RunOnDevice call sites pass
// (N, C, HxW, dmu_data, dvar_data, X_data, dX_data) and return the bool
// result directly.
95 bool ComputeBatchMomentsGradientNCHW(
104 bool ComputeBatchMomentsGradientNHWC(
// Memory layout of X, initialized from the "order" argument in the
// constructor. It selects which kernel above RunOnDevice calls.
113 const StorageOrder order_;
118 #endif // CAFFE2_OPERATORS_BATCH_MOMENTS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...