Caffe2 - C++ API
A deep learning, cross platform ML framework
moments_op.h
1 #ifndef CAFFE2_OPERATORS_MOMENTS_OP_H_
2 #define CAFFE2_OPERATORS_MOMENTS_OP_H_
3 
#include <algorithm>
#include <cstdint>
#include <numeric>
#include <utility>
#include <vector>

#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/math.h"
10 
11 namespace caffe2 {
12 
13 template <typename T, class Context>
14 class MomentsOp final : public Operator<Context> {
15  public:
16  USE_OPERATOR_CONTEXT_FUNCTIONS;
17 
18  template <class... Args>
19  explicit MomentsOp(Args&&... args)
20  : Operator<Context>(std::forward<Args>(args)...),
21  axes_(this->template GetRepeatedArgument<int>("axes")),
22  OP_SINGLE_ARG(bool, "keepdims", keep_dims_, true) {}
23 
24  bool RunOnDevice() override {
25  const auto& X = Input(0);
26 
27  const int ndim = X.dim();
28  if (axes_.empty()) {
29  axes_.resize(ndim);
30  std::iota(axes_.begin(), axes_.end(), 0);
31  } else {
32  std::sort(axes_.begin(), axes_.end());
33  CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative.");
34  CAFFE_ENFORCE_LT(
35  axes_.back(),
36  ndim,
37  "Axes ids must be smaller than the dimensions of input.");
38  }
39  const std::vector<int> X_dims(X.sizes().cbegin(), X.sizes().cend());
40  std::vector<int> Y_dims = X_dims;
41  for (const int axis : axes_) {
42  Y_dims[axis] = 1;
43  }
44  std::vector<std::int64_t> output_dims;
45  output_dims.reserve(ndim);
46  std::size_t cur_axis = 0;
47  for (int i = 0; i < ndim; ++i) {
48  if (cur_axis < axes_.size() && i == axes_[cur_axis]) {
49  if (keep_dims_) {
50  output_dims.push_back(1);
51  }
52  ++cur_axis;
53  } else {
54  output_dims.push_back(X_dims[i]);
55  }
56  }
57  auto* mean = Output(0, output_dims, at::dtype<T>());
58  auto* var = Output(1, output_dims, at::dtype<T>());
59  math::Moments<float, Context>(
60  X_dims.size(),
61  X_dims.data(),
62  Y_dims.data(),
63  X.template data<T>(),
64  mean->template mutable_data<T>(),
65  var->template mutable_data<T>(),
66  &context_);
67  return true;
68  }
69 
70  private:
71  std::vector<int> axes_;
72  const int keep_dims_;
73 };
74 
75 template <typename T, class Context>
76 class MomentsGradientOp final : public Operator<Context> {
77  public:
78  USE_OPERATOR_CONTEXT_FUNCTIONS;
79 
80  template <class... Args>
81  explicit MomentsGradientOp(Args&&... args)
82  : Operator<Context>(std::forward<Args>(args)...),
83  axes_(this->template GetRepeatedArgument<int>("axes")) {}
84 
85  bool RunOnDevice() override {
86  const auto& dmean = Input(0);
87  const auto& dvariance = Input(1);
88  const auto& X = Input(2);
89  const auto& mean = Input(3);
90 
91  const int ndim = X.dim();
92  if (axes_.empty()) {
93  axes_.resize(ndim);
94  std::iota(axes_.begin(), axes_.end(), 0);
95  } else {
96  std::sort(axes_.begin(), axes_.end());
97  CAFFE_ENFORCE_GE(axes_.front(), 0, "Axes ids must be non-negative.");
98  CAFFE_ENFORCE_LT(
99  axes_.back(),
100  ndim,
101  "Axes ids must be smaller than the dimensions of input.");
102  }
103  const std::vector<int> dX_dims(X.sizes().cbegin(), X.sizes().cend());
104  std::vector<int> dY_dims = dX_dims;
105  for (const int axis : axes_) {
106  dY_dims[axis] = 1;
107  }
108  auto* dX = Output(0, X.sizes(), at::dtype<T>());
109  return Compute(
110  dY_dims,
111  dX_dims,
112  dmean.template data<T>(),
113  dvariance.template data<T>(),
114  X.template data<T>(),
115  mean.template data<T>(),
116  dX->template mutable_data<T>());
117  }
118 
119  private:
120  bool Compute(
121  const std::vector<int>& dY_dims,
122  const std::vector<int>& dX_dims,
123  const T* dmean_data,
124  const T* dvariance_data,
125  const T* X_data,
126  const T* mean_data,
127  T* dX_data);
128 
129  std::vector<int> axes_;
130 };
131 
132 } // namespace caffe2
133 
134 #endif // CAFFE2_OPERATORS_MOMENTS_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13