Caffe2 - C++ API
A deep learning, cross-platform ML framework
channel_stats_op.h
1 #ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
2 #define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
3 
4 #include <string>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/utils/math.h"
9 
10 namespace caffe2 {
11 
12 template <class Context>
13 class ChannelStatsOp final : public Operator<Context> {
14  public:
15  USE_OPERATOR_CONTEXT_FUNCTIONS;
16 
17  template <class... Args>
18  explicit ChannelStatsOp(Args&&... args)
19  : Operator<Context>(std::forward<Args>(args)...),
20  order_(StringToStorageOrder(
21  this->template GetSingleArgument<std::string>("order", "NCHW"))) {
22  CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
23  }
24 
25  bool RunOnDevice() override {
26  return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
27  }
28 
29  template <typename T>
30  bool DoRunWithType() {
31  const auto& X = Input(0);
32  const int ndim = X.dim();
33  const int N = X.dim32(0);
34  const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
35  const int HxW = X.numel() / (N * C);
36  auto* sum = Output(0, {C}, at::dtype<T>());
37  auto* sumsq = Output(1, {C}, at::dtype<T>());
38  const T* X_data = X.template data<T>();
39  T* sum_data = sum->template mutable_data<T>();
40  T* sumsq_data = sumsq->template mutable_data<T>();
41  return order_ == StorageOrder::NCHW
42  ? ComputeChannelStatsNCHW<T>(N, C, HxW, X_data, sum_data, sumsq_data)
43  : ComputeChannelStatsNHWC<T>(N, C, HxW, X_data, sum_data, sumsq_data);
44  }
45 
46  private:
47  template <typename T>
48  bool
49  ComputeChannelStatsNCHW(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
50 
51  template <typename T>
52  bool
53  ComputeChannelStatsNHWC(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
54 
55  const StorageOrder order_;
56 };
57 
58 } // namespace caffe2
59 
60 #endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64