Caffe2 - C++ API
A deep learning, cross platform ML framework
tile_op.h
1 #ifndef CAFFE2_OPERATORS_TILE_OP_H_
2 #define CAFFE2_OPERATORS_TILE_OP_H_
3 
4 #include <array>
5 #include <string>
6 #include <type_traits>
7 #include <vector>
8 
9 #include "caffe2/core/common_omp.h"
10 #include "caffe2/core/context.h"
11 #include "caffe2/core/logging.h"
12 #include "caffe2/core/operator.h"
13 #include "caffe2/utils/eigen_utils.h"
14 #include "caffe2/utils/math.h"
15 
16 namespace caffe2 {
17 
18 // Copy a Blob n times along a specified axis.
19 template <class Context>
20 class TileOp final : public Operator<Context> {
21  public:
22  USE_OPERATOR_CONTEXT_FUNCTIONS;
23 
24  template <class... Args>
25  explicit TileOp(Args&&... args)
26  : Operator<Context>(std::forward<Args>(args)...),
27  OP_SINGLE_ARG(std::int32_t, "tiles", tiles_, 1),
28  OP_SINGLE_ARG(std::int32_t, "axis", axis_, 0) {}
29 
30  bool RunOnDevice() override {
31  return DispatchHelper<
33  call(this, Input(0));
34  }
35 
36  template <typename T>
37  bool DoRunWithType() {
38  if (InputSize() > 1) {
39  // We potentially have tiles and/or axis specified as inputs
40  // as well. We will check for them in that order. In other words:
41  // InputSize() == 2: tiles is specified
42  // InputSize() == 3: tiles is specified and axis.
43  // Anything specified as input will override the arguments
44  CAFFE_ENFORCE(
45  Input(1).dim() == 1 && Input(1).numel() == 1,
46  "Input `tiles` should be a vector of size 1.");
47  tiles_ = GetArgFromTensor(Input(1));
48 
49  // Because of a bug in original code, temporarily adds this part to keep
50  // backward compatibility.
51  // TODO(yangxm): Remove this part when prod runtime upgraded with fixed
52  // model config.
53  if (Input(1).template IsType<std::int64_t>()) {
54  axis_ = 0;
55  }
56 
57  if (InputSize() > 2) {
58  CAFFE_ENFORCE(
59  Input(2).dim() == 1 && Input(2).numel() == 1,
60  "Input `axis` should be a vector of size 1.");
61  axis_ = GetArgFromTensor(Input(2));
62  } else {
63  CAFFE_ENFORCE(
65  "Argument `axis` is missing and was not specified as input.");
66  }
67  } else {
68  CAFFE_ENFORCE(
70  "Argument `tiles` is missing and was not specified as input.");
71  CAFFE_ENFORCE(
73  "Argument `axis` is missing and was not specified as input.");
74  }
75 
76  const auto& X = Input(0);
77  auto* Y = Output(0);
78  const int axis = X.canonical_axis_index(axis_);
79 
80  // reshape output to be input tiled along the axis
81  std::vector<std::int64_t> Y_dims = X.sizes().vec();
82  Y_dims[axis] *= tiles_;
83  Y->Resize(Y_dims);
84 
85  // size up to (and not including) axis
86  const int outer_size = X.size_to_dim(axis);
87  // size from axis up
88  const int inner_size = X.size_from_dim(axis);
89 
90  const T* X_data = X.template data<T>();
91  T* Y_data = Y->template mutable_data<T>();
92  return DoTile<T>(outer_size, inner_size, X_data, Y_data);
93  }
94 
95  private:
96  std::int32_t GetArgFromTensor(const Tensor& tensor) {
97  CAFFE_ENFORCE(
98  tensor.IsType<std::int32_t>() || tensor.IsType<std::int64_t>());
99  std::int32_t val = -1;
100  if (tensor.IsType<std::int32_t>()) {
101  context_.template CopyToCPU<std::int32_t>(
102  1, tensor.data<std::int32_t>(), &val);
103  } else if (tensor.IsType<std::int64_t>()) {
104  std::int64_t val_int64;
105  context_.template CopyToCPU<std::int64_t>(
106  1, tensor.data<std::int64_t>(), &val_int64);
107  val = static_cast<std::int32_t>(val_int64);
108  }
109  return val;
110  }
111 
112  template <typename T>
113  bool DoTile(const int outer_size, const int inner_size, const T* X, T* Y) {
114  if (inner_size == 1) {
115  EigenArrayMap<T> Y_arr(Y, tiles_, outer_size);
116  for (int i = 0; i < outer_size; ++i) {
117  Y_arr.col(i) = X[i];
118  }
119  } else {
120  ConstEigenArrayMap<T> X_arr(X, inner_size, outer_size);
121  for (int i = 0; i < outer_size; ++i) {
122  EigenArrayMap<T>(Y + i * tiles_ * inner_size, inner_size, tiles_)
123  .colwise() = X_arr.col(i);
124  }
125  }
126  return true;
127  }
128 
129  std::int32_t tiles_;
130  std::int32_t axis_;
131 };
132 
133 template <class Context>
134 class TileGradientOp final : public Operator<Context> {
135  public:
136  USE_OPERATOR_CONTEXT_FUNCTIONS;
137 
138  template <class... Args>
139  explicit TileGradientOp(Args&&... args)
140  : Operator<Context>(std::forward<Args>(args)...),
141  OP_SINGLE_ARG(std::int32_t, "tiles", tiles_, 1),
142  OP_SINGLE_ARG(std::int32_t, "axis", axis_, 0) {}
143 
144  bool RunOnDevice() override {
145  return DispatchHelper<
147  call(this, Input(0));
148  }
149 
150  template <typename T>
151  bool DoRunWithType() {
152  if (InputSize() > 1) {
153  // We potentially have tiles and/or axis specified as inputs
154  // as well. We will check for them in that order. In other words:
155  // InputSize() == 2: tiles is specified
156  // InputSize() == 3: tiles is specified and axis.
157  // Anything specified as input will override the arguments
158  CAFFE_ENFORCE(
159  Input(1).dim() == 1 && Input(1).numel() == 1,
160  "Input `tiles` should be a vector of size 1.");
161  tiles_ = GetArgFromTensor(Input(1));
162  if (InputSize() > 2) {
163  CAFFE_ENFORCE(
164  Input(2).dim() == 1 && Input(2).numel() == 1,
165  "Input `axis` should be a vector of size 1.");
166  axis_ = GetArgFromTensor(Input(2));
167  } else {
168  CAFFE_ENFORCE(
170  "Argument `axis` is missing and was not specified as input.");
171  }
172  } else {
173  CAFFE_ENFORCE(
174  OperatorBase::HasArgument("tiles"),
175  "Argument `tiles` is missing and was not specified as input.");
176  CAFFE_ENFORCE(
178  "Argument `axis` is missing and was not specified as input.");
179  }
180 
181  const auto& dY = Input(0);
182  auto* dX = Output(0);
183  const int axis = dY.canonical_axis_index(axis_);
184 
185  // reshape output to be input "untiled" along the axis
186  std::vector<std::int64_t> X_dims = dY.sizes().vec();
187  CAFFE_ENFORCE_EQ(X_dims[axis] % tiles_, 0);
188  X_dims[axis] /= tiles_;
189  dX->Resize(X_dims);
190 
191  // size up to (and not including) axis
192  const int outer_size = dX->size_to_dim(axis);
193  // size from axis up
194  const int inner_size = dX->size_from_dim(axis);
195 
206  const T* dY_data = dY.template data<T>();
207  T* dX_data = dX->template mutable_data<T>();
208  return DoTileGradient<T>(outer_size, inner_size, dY_data, dX_data);
209  }
210 
211  private:
212  std::int32_t GetArgFromTensor(const Tensor& tensor) {
213  CAFFE_ENFORCE(
214  tensor.IsType<std::int32_t>() || tensor.IsType<std::int64_t>());
215  std::int32_t val = -1;
216  if (tensor.IsType<std::int32_t>()) {
217  context_.template CopyToCPU<std::int32_t>(
218  1, tensor.data<std::int32_t>(), &val);
219  } else if (tensor.IsType<std::int64_t>()) {
220  std::int64_t val_int64;
221  context_.template CopyToCPU<std::int64_t>(
222  1, tensor.data<std::int64_t>(), &val_int64);
223  val = static_cast<std::int32_t>(val_int64);
224  }
225  return val;
226  }
227 
228  template <typename T>
229  bool DoTileGradient(
230  const int outer_size,
231  const int inner_size,
232  const T* dY,
233  T* dX) {
234  if (inner_size == 1) {
235  const std::array<int, 2> dY_dims = {outer_size, tiles_};
236  const std::array<int, 2> dX_dims = {outer_size, 1};
237  math::ReduceSum<T, Context>(
238  2, dY_dims.data(), dX_dims.data(), T(1), dY, dX, &context_);
239  } else {
240  math::CopyMatrix<T, Context>(
241  outer_size,
242  inner_size,
243  dY,
244  inner_size * tiles_,
245  dX,
246  inner_size,
247  &context_);
248  for (int i = 0; i < outer_size; ++i) {
249  const T* dY_ptr = dY + i * tiles_ * inner_size;
250  T* dX_ptr = dX + i * inner_size;
251  for (int j = 1; j < tiles_; ++j) {
252  math::Add<T, Context>(
253  inner_size, dX_ptr, dY_ptr + j * inner_size, dX_ptr, &context_);
254  }
255  }
256  }
257  return true;
258  }
259 
260  std::int32_t tiles_;
261  std::int32_t axis_;
262 
263  Tensor ones_;
264 };
265 
266 } // namespace caffe2
267 
268 #endif // CAFFE2_OPERATORS_TILE_OP_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
Definition: operator.h:702
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
bool HasArgument(const string &name) const
Checks if the operator has an argument of the given name.
Definition: operator.h:70