// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// concat_split_op.h
1 #ifndef CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
2 #define CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
3 
4 #include "caffe2/core/context.h"
5 #include "caffe2/core/operator.h"
6 #include "caffe2/core/types.h"
7 #include "caffe2/utils/math.h"
8 
9 namespace caffe2 {
10 
11 namespace {
12 inline int GetDimFromOrderString(const string& str) {
13  auto order = StringToStorageOrder(str);
14  switch (order) {
15  case StorageOrder::NHWC:
16  return 3;
17  case StorageOrder::NCHW:
18  return 1;
19  default:
20  CAFFE_THROW("Unsupported storage order: ", str);
21  return -1;
22  }
23 }
24 } // namespace
25 
26 template <class Context>
27 class SplitOp final : public Operator<Context> {
28  public:
29  static const int kSplitOpInputSize = 2;
30 
31  USE_OPERATOR_CONTEXT_FUNCTIONS;
32  template <class... Args>
33  explicit SplitOp(Args&&... args)
34  : Operator<Context>(std::forward<Args>(args)...),
35  split_(this->template GetRepeatedArgument<int>("split")) {
36  CAFFE_ENFORCE(
37  !(OperatorBase::HasArgument("axis") &&
38  OperatorBase::HasArgument("order")),
39  "You shouldn't specify both the dim to split, and the order "
40  "in the case of 4-D images.");
41  if (OperatorBase::HasArgument("axis")) {
42  axis_ = this->template GetSingleArgument<int>("axis", -1);
43  // only exists for computing the gradient of a Concat with 'add_axis'
44  add_axis_ = this->template GetSingleArgument<int>("add_axis", 0);
45  } else {
46  axis_ = GetDimFromOrderString(
47  this->template GetSingleArgument<string>("order", "NCHW"));
48  add_axis_ = 0;
49  }
50  }
51 
52  bool RunOnDevice() override;
53 
54  protected:
55  int axis_;
56  int add_axis_;
57  vector<int> split_;
58  // Input: X, optionally split
59  // The split tensor is stored in CPU.
60 };
61 
62 template <class Context>
63 class SplitByLengthsOp final : public Operator<Context> {
64  public:
65  USE_OPERATOR_CONTEXT_FUNCTIONS;
66  template <class... Args>
67  explicit SplitByLengthsOp(Args&&... args)
68  : Operator<Context>(std::forward<Args>(args)...) {
69  CAFFE_ENFORCE(
70  !(OperatorBase::HasArgument("axis") &&
71  OperatorBase::HasArgument("order")),
72  "You shouldn't specify both the dim to split, and the order "
73  "in the case of 4-D images.");
74  if (OperatorBase::HasArgument("axis")) {
75  axis_ = this->template GetSingleArgument<int>("axis", 0);
76  } else {
77  axis_ = GetDimFromOrderString(
78  this->template GetSingleArgument<string>("order", "NCHW"));
79  }
80  }
81 
82  bool RunOnDevice() override;
83 
84  protected:
85  int axis_;
86  Tensor inclusive_scan_buffer_{Context::GetDeviceType()};
87  Tensor inclusive_scan_length_buffer_{Context::GetDeviceType()};
88  // Input: X, optionally split
89  // The split tensor is stored in CPU.
90 };
91 
92 template <class Context>
93 class ConcatOp final : public Operator<Context> {
94  public:
95  USE_OPERATOR_CONTEXT_FUNCTIONS;
96  template <class... Args>
97  explicit ConcatOp(Args&&... args)
98  : Operator<Context>(std::forward<Args>(args)...) {
99  CAFFE_ENFORCE(
100  !(OperatorBase::HasArgument("axis") &&
101  OperatorBase::HasArgument("order")),
102  "You shouldn't specify both the dim to concat, and the order "
103  "in the case of 4-D images.");
104  if (OperatorBase::HasArgument("axis")) {
105  axis_ = this->template GetSingleArgument<int>("axis", -1);
106  add_axis_ = this->template GetSingleArgument<int>("add_axis", 0);
107  } else {
108  axis_ = GetDimFromOrderString(
109  this->template GetSingleArgument<string>("order", "NCHW"));
110  add_axis_ = 0;
111  }
112  }
113 
114  bool RunOnDevice() override;
115 
116  protected:
117  int axis_;
118  int add_axis_;
119  // Input: a number of tensors. Output: Y, split
120  // The split are stored in CPU.
121 };
122 
123 // Implementations
124 template <class Context>
126  auto& input = Input(0);
127  int canonical_axis = input.canonical_axis_index(axis_);
128  CAFFE_ENFORCE_LT(
129  canonical_axis, input.dim(), "Axis not in input ndim range.");
130  const int input_channels = input.dim32(canonical_axis);
131  const int* axis_data;
132  vector<int> equal_split;
133  if (InputSize() == kSplitOpInputSize) {
134  // We obtain split from the input tensor.
135  CAFFE_ENFORCE_EQ(
136  split_.size(),
137  0,
138  "If you set split with an input blob, do not pass in "
139  "split in the argument.");
140  auto& split_tensor = this->template Input<Tensor>(1, CPU);
141  CAFFE_ENFORCE_EQ(split_tensor.numel(), OutputSize());
142  axis_data = split_tensor.template data<int>();
143  } else if (split_.size() == 0) {
144  CAFFE_ENFORCE_EQ(
145  input_channels % OutputSize(),
146  0,
147  "If you did not specify split explicitly, the number of "
148  "input channels should be divisible by the output size.");
149  equal_split.resize(OutputSize(), input_channels / OutputSize());
150  axis_data = equal_split.data();
151  } else {
152  // We obtain split from the parameters.
153  CAFFE_ENFORCE_EQ(
154  split_.size(),
155  OutputSize(),
156  "The number of splits specified should be equal to the "
157  "number of outputs.");
158  axis_data = split_.data();
159  }
160 
161  CAFFE_ENFORCE_EQ(
162  add_axis_ ? OutputSize()
163  : std::accumulate(axis_data, axis_data + OutputSize(), 0),
164  input_channels,
165  "Sum of split dimensions do not match: should be ",
166  input_channels);
167  vector<int64_t> output_dims(input.sizes().vec());
168  int before = 1, after = 1;
169  for (int i = 0; i < canonical_axis; ++i) {
170  before *= input.dim32(i);
171  }
172  for (int i = canonical_axis + 1; i < input.dim(); ++i) {
173  after *= input.dim32(i);
174  }
175  if (add_axis_) {
176  output_dims.erase(output_dims.begin() + canonical_axis);
177  }
178  size_t input_offset = 0;
179  for (int i = 0; i < OutputSize(); ++i) {
180  auto* output = Output(i);
181  auto axis_dim = add_axis_ ? 1 : axis_data[i];
182  if (!add_axis_) {
183  output_dims[canonical_axis] = axis_data[i];
184  }
185  output->Resize(output_dims);
186  math::CopyMatrix<Context>(
187  input.itemsize(),
188  before,
189  axis_dim * after,
190  static_cast<const char*>(input.raw_data()) + input_offset,
191  input.dim32(canonical_axis) * after,
192  output->raw_mutable_data(input.dtype()),
193  axis_dim * after,
194  &context_,
195  input.dtype().copy());
196  input_offset += axis_dim * after * input.itemsize();
197  }
198  return true;
199 }
200 
201 // Implementations
202 template <class Context>
204  auto& input = Input(0);
205  auto& length = this->template Input<Tensor>(1, CPU);
206  auto length_length = length.numel();
207  CAFFE_ENFORCE_EQ(
208  length_length % OutputSize(),
209  0,
210  "len(Lengths) should be divisible by OutputSize().");
211  int canonical_axis = input.canonical_axis_index(axis_);
212  CAFFE_ENFORCE_LT(
213  canonical_axis, input.dim(), "Axis not in input ndim range.");
214  const int input_channels = input.dim32(canonical_axis);
215  const auto* axis_data = length.template data<int>();
216  CAFFE_ENFORCE_EQ(
217  std::accumulate(axis_data, axis_data + length.numel(), 0),
218  input_channels,
219  "Sum of split dimensions do not match: should be ",
220  input_channels);
221  vector<int64_t> output_dims(input.sizes().vec());
222  int before = input.size_to_dim(canonical_axis);
223  int after = input.size_from_dim(canonical_axis + 1);
224  size_t input_offset = 0;
225  for (int i = 0; i < OutputSize(); ++i) {
226  auto* output = Output(i);
227  const auto* axis_offset = axis_data + length_length / OutputSize() * i;
228  auto axis_dim = std::accumulate(
229  axis_offset, axis_offset + length_length / OutputSize(), 0);
230  output_dims[canonical_axis] = axis_dim;
231  output->Resize(output_dims);
232  math::CopyMatrix<Context>(
233  input.itemsize(),
234  before,
235  axis_dim * after,
236  static_cast<const char*>(input.raw_data()) + input_offset,
237  input.dim32(canonical_axis) * after,
238  output->raw_mutable_data(input.dtype()),
239  axis_dim * after,
240  &context_,
241  input.dtype().copy());
242  input_offset += axis_dim * after * input.itemsize();
243  }
244  return true;
245 }
246 
247 template <class Context>
249  auto* output = Output(0);
250 
251  // We can override default options(Context::GetDeviceType())
252  // by explictly passing in device type we want
253  Tensor* split = Output(
254  1, std::vector<int64_t>(1, InputSize()), at::dtype<int>().device(CPU));
255  int* axis_data = split->template mutable_data<int>();
256  auto& input_zero = Input(0);
257  int adj_size = input_zero.dim() + (add_axis_ ? 1 : 0);
258  int canonical_axis = canonical_axis_index_(axis_, adj_size);
259  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
260  for (int i = 1; i < InputSize(); ++i) {
261  CAFFE_ENFORCE(
262  Input(i).dtype() == input_zero.dtype(),
263  "All inputs must have the same type, expected: ",
264  input_zero.dtype().name(),
265  " but got: ",
266  Input(i).dtype().name(),
267  " for input: ",
268  i);
269  }
270 
271  int before = 1, after = 1;
272  vector<int64_t> output_dims(input_zero.sizes().vec());
273  for (int i = 0; i < input_zero.dim(); ++i) {
274  if (i == canonical_axis && !add_axis_) {
275  continue;
276  }
277  int dim = input_zero.dim32(i);
278  if (i < canonical_axis) {
279  before *= dim;
280  } else { // i > canonical_axis || i == canonical_axis && add_axis_
281  after *= dim;
282  }
283  // check the input dims are compatible.
284  for (int j = 1; j < InputSize(); ++j) {
285  int dim_j = Input(j).dim32(i);
286  CAFFE_ENFORCE(
287  dim == dim_j,
288  "Expect dimension = ",
289  dim,
290  " got ",
291  dim_j,
292  " at axis = ",
293  i,
294  " for input: ",
295  j,
296  ". The input tensors can only have different dimensions "
297  "when arg 'add_axis' = 0 and along the axis = ",
298  canonical_axis,
299  " <",
300  Input(0).sizes(),
301  "> vs <",
302  Input(j).sizes(),
303  ">.");
304  }
305  }
306 
307  int output_channels = 0;
308  for (int i = 0; i < InputSize(); ++i) {
309  axis_data[i] = add_axis_ ? 1 : Input(i).dim32(canonical_axis);
310  output_channels += axis_data[i];
311  }
312  if (add_axis_) {
313  output_dims.insert(output_dims.begin() + canonical_axis, output_channels);
314  } else {
315  output_dims[canonical_axis] = output_channels;
316  }
317  output->Resize(output_dims);
318  size_t output_offset = 0;
319  for (int i = 0; i < InputSize(); ++i) {
320  auto& input = Input(i);
321  auto axis_dim = add_axis_ ? 1 : input.dim32(canonical_axis);
322  math::CopyMatrix<Context>(
323  input.itemsize(),
324  before,
325  axis_dim * after,
326  input.raw_data(),
327  axis_dim * after,
328  static_cast<char*>(output->raw_mutable_data(input_zero.dtype())) +
329  output_offset,
330  output_channels * after,
331  &context_,
332  input_zero.dtype().copy());
333  output_offset += axis_dim * after * input.itemsize();
334  }
335  return true;
336 }
337 
338 OpSchema::Cost CostInferenceForConcat(
339  const OperatorDef& def,
340  const std::vector<TensorShape>& in);
341 
342 std::vector<TensorShape> TensorInferenceForConcat(
343  const OperatorDef& def,
344  const std::vector<TensorShape>& in);
345 
346 } // namespace caffe2
347 
348 #endif // CAFFE2_OPERATORS_CONCAT_SPLIT_OP_H_
// Extraction residue (Doxygen index entries), kept as comments:
//   const Tensor& Input(int idx, DeviceType type = Context::GetDeviceType())
//     Retrieve a non-owning reference to the input at position 'idx' for
//     this operator. (Definition: operator.h:702)
//   A global dictionary that holds information about what Caffe2 modules
//     have been loaded in the current runtime. (Definition: blob.h:13)
//   bool HasArgument(const string& name) const
//     Checks if the operator has an argument of the given name.
//     (Definition: operator.h:70)