Caffe2 - C++ API
A deep learning, cross platform ML framework
reduce_front_back_max_ops.h
1 #ifndef CAFFE2_OPERATORS_REDUCE_FRONT_BACK_MAX_OPS_H_
2 #define CAFFE2_OPERATORS_REDUCE_FRONT_BACK_MAX_OPS_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/logging.h"
6 #include "caffe2/core/operator.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 template <typename T, class Context, bool FIRSTDIMS>
12 class MaxReduceDimsOp final : public Operator<Context> {
13  public:
14  template <class... Args>
15  explicit MaxReduceDimsOp(Args&&... args)
16  : Operator<Context>(std::forward<Args>(args)...),
17  num_reduce_dims_(
18  this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
19 
20  USE_OPERATOR_CONTEXT_FUNCTIONS;
21 
22  bool RunOnDevice() {
23  auto& X = Input(0);
24 
25  CAFFE_ENFORCE(
26  num_reduce_dims_ >= 0 && num_reduce_dims_ <= X.dim(),
27  "For N-dim input tensor, support num_reduce_dims in range [0, N].");
28 
29  const int rows = FIRSTDIMS ? X.size_to_dim(num_reduce_dims_)
30  : X.size_to_dim(X.dim() - num_reduce_dims_);
31  const int cols = FIRSTDIMS ? X.size_from_dim(num_reduce_dims_)
32  : X.size_from_dim(X.dim() - num_reduce_dims_);
33 
34  vector<int64_t> output_shape;
35  int start_index = FIRSTDIMS ? num_reduce_dims_ : 0;
36  int end_index =
37  FIRSTDIMS ? X.dim() : X.dim() - num_reduce_dims_;
38 
39  for (int i = start_index; i < end_index; ++i) {
40  output_shape.push_back(X.sizes()[i]);
41  }
42  auto* Y = Output(0, output_shape, at::dtype<float>());
43  float* out_data = Y->template mutable_data<float>();
44 
45  if (cols == 0 || rows == 0) {
46  math::Set(Y->numel(), static_cast<float>(0), out_data, &context_);
47  return true;
48  }
49 
50  const int32_t* lengths_data = nullptr;
51  if (InputSize() > 1) {
52  const auto& lengths = Input(1);
53  lengths_data = lengths.template data<int32_t>();
54  CAFFE_ENFORCE(
55  num_reduce_dims_ == 1,
56  "Given lengths input, the number of reduce dimensions should be one.");
57  const int batch_size = FIRSTDIMS ? cols : rows;
58  CAFFE_ENFORCE(
59  lengths.numel() == batch_size,
60  "The size of lengths vector doesn't match the batch size.");
61  }
62 
63  const float* data = X.template data<float>();
64  Compute(rows, cols, data, lengths_data, out_data);
65  return true;
66  }
67 
68  protected:
69  void Compute(
70  int rows,
71  int cols,
72  const float* data,
73  const int32_t* lengths_data,
74  float* out_data);
75 
76  int num_reduce_dims_;
77 };
78 
79 template <typename T, class Context, bool FIRSTDIMS>
80 class MaxReduceDimsGradientOp final : public Operator<Context> {
81  public:
82  template <class... Args>
83  explicit MaxReduceDimsGradientOp(Args&&... args)
84  : Operator<Context>(std::forward<Args>(args)...),
85  num_reduce_dims_(
86  this->template GetSingleArgument<int32_t>("num_reduce_dim", 1)) {}
87 
88  USE_OPERATOR_CONTEXT_FUNCTIONS;
89 
90  bool RunOnDevice() override {
91  auto& dY = Input(0);
92  auto& X = Input(1);
93  auto& Y = Input(2);
94 
95  auto* dX = Output(0, X.sizes(), at::dtype<float>());
96  const int rows = FIRSTDIMS ? X.size_to_dim(num_reduce_dims_)
97  : X.size_to_dim(X.dim() - num_reduce_dims_);
98  const int cols = FIRSTDIMS ? X.size_from_dim(num_reduce_dims_)
99  : X.size_from_dim(X.dim() - num_reduce_dims_);
100 
101  const float* dYdata = dY.template data<float>();
102  const float* Xdata = X.template data<float>();
103  const float* Ydata = Y.template data<float>();
104 
105  const int32_t* lengths_data = nullptr;
106  if (InputSize() > 3) {
107  const auto& lengths = Input(3);
108  lengths_data = lengths.template data<int32_t>();
109  CAFFE_ENFORCE(
110  num_reduce_dims_ == 1,
111  "Given lengths input, the number of reduce dimensions should be one.");
112  const int batch_size = FIRSTDIMS ? cols : rows;
113  CAFFE_ENFORCE(
114  lengths.numel() == batch_size,
115  "The size of lengths vector doesn't match the batch size.");
116  }
117 
118  float* dXdata = dX->template mutable_data<float>();
119  Compute(rows, cols, dYdata, Xdata, Ydata, lengths_data, dXdata);
120  return true;
121  }
122 
123  protected:
124  void Compute(
125  int rows,
126  int cols,
127  const float* dYdata,
128  const float* Xdata,
129  const float* Ydata,
130  const int32_t* lengths_data,
131  float* dXdata);
132 
133  int num_reduce_dims_;
134 };
135 
136 } // namespace caffe2
137 
138 #endif // CAFFE2_OPERATORS_REDUCE_FRONT_BACK_MAX_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13