1 #ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_ 2 #define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/utils/math.h" 12 template <
class Context>
// Operator boilerplate macro, presumably defined in caffe2/core/operator.h
// (included at the top of this header).
// NOTE(review): it presumably makes the Operator<Context> members that this
// class template calls unqualified -- e.g. Input(0) and Output(...) in
// DoRunWithType -- visible in this scope; confirm against operator.h.
15 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Variadic constructor (its declarator and base-class initializer are not
// visible in this view). Reads the string argument "order", defaulting to
// "NCHW", converts it with StringToStorageOrder and stores it in order_.
// The `this->template` form is the C++ disambiguation required to call a
// member template through a dependent `this` (presumably GetSingleArgument
// comes from the Operator<Context> base).
17 template <
class... Args>
20 order_(StringToStorageOrder(
21 this->
template GetSingleArgument<std::string>(
"order",
"NCHW"))) {
// Fail construction for an order string that maps to StorageOrder::UNKNOWN.
// DoRunWithType treats every non-NCHW order as NHWC, so it must never see
// UNKNOWN.
22 CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
// Operator entry point; overrides the base-class RunOnDevice.
// NOTE(review): the body is not visible in this view. Presumably it
// dispatches on the element type of Input(0) to DoRunWithType<T>(); confirm.
25 bool RunOnDevice()
override {
// Produces two outputs of shape {C} and element type T:
//   Output(0) -> `sum` and Output(1) -> `sumsq`.
// Both are filled by the layout-specific kernel ComputeChannelStatsNCHW or
// ComputeChannelStatsNHWC, whose result is returned.
// T is presumably the template parameter declared just above this line
// (not visible in this view).
//
// Sizes passed to the kernel:
//   N   = X.dim32(0)
//   C   = X.dim32(1) for NCHW, otherwise the last dimension
//   HxW = X.numel() / (N * C)
30 bool DoRunWithType() {
31 const auto& X =
Input(0);
32 const int ndim = X.dim();
33 const int N = X.dim32(0);
// NOTE(review): the NCHW branch reads dim 1, so X presumably must have
// ndim >= 2. Nothing in this block enforces that.
34 const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
// NOTE(review): this divides by N * C. If X can have a zero-sized batch
// or channel dimension, this is an integer division by zero (undefined
// behaviour). Consider handling the empty-input case before this line.
// If numel() returns a 64-bit count (TODO confirm), the quotient is also
// narrowed to int here.
35 const int HxW = X.numel() / (N * C);
36 auto* sum = Output(0, {C}, at::dtype<T>());
37 auto* sumsq = Output(1, {C}, at::dtype<T>());
38 const T* X_data = X.template data<T>();
39 T* sum_data = sum->template mutable_data<T>();
40 T* sumsq_data = sumsq->template mutable_data<T>();
// Pick the kernel by layout. Any order other than NCHW takes the NHWC
// path; the constructor already rejected StorageOrder::UNKNOWN.
41 return order_ == StorageOrder::NCHW
42 ? ComputeChannelStatsNCHW<T>(N, C, HxW, X_data, sum_data, sumsq_data)
43 : ComputeChannelStatsNHWC<T>(N, C, HxW, X_data, sum_data, sumsq_data);
// Kernel for NCHW-ordered input. It is only declared here; the definition
// lives elsewhere (presumably specialized per Context and T).
//
// Arguments:
//   N, C, HxW  sizes computed by DoRunWithType
//   X          input data
//   sum, sumsq output buffers, each allocated with shape {C} by the caller
//
// NOTE(review): the template header and return type precede this line and
// are not visible in this view. DoRunWithType returns this call's result
// from a bool function.
49 ComputeChannelStatsNCHW(
int N,
int C,
int HxW,
const T* X,
T* sum,
T* sumsq);
// Kernel for NHWC-ordered input (channel is the last dimension). It is only
// declared here; the definition lives elsewhere (presumably specialized per
// Context and T).
//
// Arguments:
//   N, C, HxW  sizes computed by DoRunWithType
//   X          input data
//   sum, sumsq output buffers, each allocated with shape {C} by the caller
//
// NOTE(review): the template header and return type precede this line and
// are not visible in this view.
53 ComputeChannelStatsNHWC(
int N,
int C,
int HxW,
const T* X,
T* sum,
T* sumsq);
// Storage order parsed from the "order" argument in the constructor. It is
// const, so it is fixed for the operator's lifetime, and never UNKNOWN
// (the constructor enforces that).
55 const StorageOrder order_;
60 #endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...