Caffe2 - C++ API
A deep learning, cross-platform ML framework
concat_split_op.cc
#include "caffe2/operators/concat_split_op.h"

namespace caffe2 {
namespace {
std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>> splitOpDevInfer(
    const OperatorDef& def) {
  auto op_device =
      def.has_device_option() ? def.device_option() : DeviceOption();
  vector<DeviceOption> in_dev(def.input_size(), op_device);
  vector<DeviceOption> out_dev(def.output_size(), op_device);

  // If the split sizes come from the second input tensor, that input always
  // lives on CPU, regardless of the op's device option.
  if (def.input_size() == SplitOp<CPUContext>::kSplitOpInputSize) {
    CAFFE_ENFORCE_GT(in_dev.size(), 1);
    in_dev[1] = DeviceOption();
  }
  return std::make_pair(in_dev, out_dev);
}
} // namespace

REGISTER_CPU_OPERATOR(Split, SplitOp<CPUContext>);
REGISTER_CPU_OPERATOR(SplitByLengths, SplitByLengthsOp<CPUContext>);
OPERATOR_SCHEMA(Split)
    .NumInputs(1, 2)
    .NumOutputs(1, INT_MAX)
    .Input(0, "input", "(*Tensor*): tensor to split")
    .Input(
        1,
        "split",
        "(*Tensor`<int>`*): [OPTIONAL] list of output lengths (see also arg `split`)")
    .Arg("axis", "(*int*): axis to split on")
    .Arg("split", "(*Tuple(int)*): length of each output")
    .Arg(
        "order",
        "(*string*): order of dimensions of input and output blobs; either \"NCHW\" or \"NHWC\"")
    .Output(0, "[output_0, output_1, ...]", "(*Tensor*): output tensor")
    .DeviceInferenceFunction(splitOpDevInfer)
    .SetDoc(R"DOC(
Split an `input` tensor into a list of tensors along the specified `axis`. The lengths of the split can be given via the `split` argument or via the optional second input blob to the operator. Otherwise, the tensor is split into equal-sized parts.

Github Links:
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/concat_split_op.cc

<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Split",
    ["input"],
    ["output_0", "output_1", "output_2"],
    split=(3, 2, 4),
    axis=0
)

workspace.FeedBlob("input", np.random.randint(10, size=(9,)))
print("input:", workspace.FetchBlob("input"))
workspace.RunOperatorOnce(op)
print("output_0:", workspace.FetchBlob("output_0"))
print("output_1:", workspace.FetchBlob("output_1"))
print("output_2:", workspace.FetchBlob("output_2"))

```

**Result**

```

input: [2 2 6 6 6 0 5 7 4]
output_0: [2 2 6]
output_1: [6 6]
output_2: [0 5 7 4]

```

</details>

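<details>

<summary> <b>Example 2</b> </summary>

A minimal sketch of the second-input form, assuming a 1-D int32 `split` blob on CPU (no result is shown, since the random input makes the printed values vary):

**Code**

```

from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Split",
    ["input", "split"],  # split sizes come from the second input blob
    ["output_0", "output_1"],
    axis=0
)

workspace.FeedBlob("input", np.random.randint(10, size=(5,)))
# The lengths must sum to the size of `input` along `axis`.
workspace.FeedBlob("split", np.array([2, 3], dtype=np.int32))
workspace.RunOperatorOnce(op)
print("output_0:", workspace.FetchBlob("output_0"))  # first 2 elements
print("output_1:", workspace.FetchBlob("output_1"))  # remaining 3 elements

```

</details>
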
84 )DOC")
85  .InheritOnnxSchema();
86 
87 OPERATOR_SCHEMA(SplitByLengths)
88  .NumInputs(2)
89  .NumOutputs(1, INT_MAX)
90  .Input(0, "input", "The tensor to split")
91  .Input(1, "legnths", "The tensor `l_i` indicates the logic block of input.")
92  .Arg("axis", "Which axis to split on")
93  .Arg("order", "Either NHWC or NCWH, will split on C axis, defaults to NCHW")
94  .DeviceInferenceFunction([](const OperatorDef& def) {
95  auto op_device =
96  def.has_device_option() ? def.device_option() : DeviceOption();
97  vector<DeviceOption> in_dev(def.input_size(), op_device);
98  vector<DeviceOption> out_dev(def.output_size(), op_device);
99  // lengths input should be on CPU
100  in_dev[1] = DeviceOption();
101  return std::make_pair(in_dev, out_dev);
102  })
103  .SetDoc(R"DOC(
Split a tensor into a list of tensors along the specified `axis`, given a
`lengths` input. If `K` outputs are provided, the op assumes
`len(lengths) % K == 0` and splits `input` into `K` parts, where part `i` has
length `sum(lengths[i*k:i*k+k])` with `k = len(lengths) / K`.
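
A minimal sketch of typical usage, assuming a 1-D int32 `lengths` blob on CPU whose values sum to the size of `input` along `axis`; with `np.arange(10)` as input, `output_0` should hold the first 4 elements and `output_1` the remaining 6:

<details>

<summary> <b>Example</b> </summary>

**Code**

```

from caffe2.python import core, workspace
import numpy as np

workspace.ResetWorkspace()

op = core.CreateOperator(
    "SplitByLengths",
    ["input", "lengths"],
    ["output_0", "output_1"],
    axis=0
)

workspace.FeedBlob("input", np.arange(10))
workspace.FeedBlob("lengths", np.array([4, 6], dtype=np.int32))
workspace.RunOperatorOnce(op)
print("output_0:", workspace.FetchBlob("output_0"))
print("output_1:", workspace.FetchBlob("output_1"))

```

</details>
)DOC");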

OpSchema::Cost CostInferenceForConcat(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  ArgumentHelper helper(def);
  const int axis = helper.HasArgument("axis")
      ? helper.GetSingleArgument<int>("axis", -1)
      : GetDimFromOrderString(
            helper.GetSingleArgument<string>("order", "NCHW"));
  bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
  const int canonical_axis = canonical_axis_index_(axis, in[0].dims_size());
  CAFFE_ENFORCE_GT(in.size(), 0);
  vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
  if (add_axis) {
    out_shape.insert(out_shape.begin() + canonical_axis, in.size());
  } else {
    for (int i = 1; i < in.size(); ++i) {
      out_shape[canonical_axis] += in[i].dims(canonical_axis);
    }
  }
  uint64_t nElemRead = 1;
  for (int i = 0; i < in.size(); ++i) {
    nElemRead += nElemFromDim(in[i]);
  }
  int size = 1;
  for (auto& s : out_shape) {
    size *= s;
  }

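  // Concat performs no arithmetic: every input element is simply copied to
  // the output, so the cost is pure memory traffic with zero FLOPs.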
  struct OpSchema::Cost cost;
  cost.flops = 0;
  cost.bytes_read = nElemRead * sizeof(in[0].data_type());
  cost.bytes_written = size * sizeof(in[0].data_type());
  cost.params_bytes = 0;
  return cost;
}

namespace {
std::pair<std::vector<DeviceOption>, std::vector<DeviceOption>>
concatOpDevInfer(const OperatorDef& def) {
  auto op_device =
      def.has_device_option() ? def.device_option() : DeviceOption();
  vector<DeviceOption> in_dev(def.input_size(), op_device);
  vector<DeviceOption> out_dev(def.output_size(), op_device);

  // 2nd output's type is always CPU irrespective of op's device option.
  CAFFE_ENFORCE_GT(out_dev.size(), 1);
  out_dev[1] = DeviceOption();
  return std::make_pair(in_dev, out_dev);
}
} // namespace

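// Shape inference for Concat: every input must match input 0 on all
// dimensions except the concat axis (or on all dimensions when `add_axis` is
// set, in which case a new axis of size #inputs is inserted instead).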
vector<TensorShape> TensorInferenceForConcat(
    const OperatorDef& def,
    const vector<TensorShape>& in) {
  ArgumentHelper helper(def);
  const int axis = helper.HasArgument("axis")
      ? helper.GetSingleArgument<int>("axis", -1)
      : GetDimFromOrderString(
            helper.GetSingleArgument<string>("order", "NCHW"));
  bool add_axis = helper.GetSingleArgument<int>("add_axis", 0) != 0;
  int adj_size = in[0].dims_size() + (add_axis ? 1 : 0);
  const int canonical_axis = canonical_axis_index_(axis, adj_size);
  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
  CAFFE_ENFORCE_GT(in.size(), 0);
  vector<int> split_shape(1, in.size());
  vector<int> out_shape(in[0].dims().begin(), in[0].dims().end());
  if (add_axis) {
    for (int i = 1; i < in.size(); ++i) {
      CAFFE_ENFORCE_EQ(
          in[0].dims().size(),
          in[i].dims().size(),
          "All inputs of Concat should have same dims when add_axis = 1. "
          "Got different sizes for inputs 0 and ",
          i);
      for (int j = 0; j < in[0].dims().size(); ++j) {
        CAFFE_ENFORCE_EQ(
            in[0].dims(j),
            in[i].dims(j),
            "All inputs of Concat should have same dims when add_axis = 1. "
            "Got different dims for inputs 0 and ",
            i,
            ". At dim: ",
            j);
      }
    }
    out_shape.insert(out_shape.begin() + canonical_axis, in.size());
  } else {
    for (int i = 1; i < in.size(); ++i) {
      CAFFE_ENFORCE_EQ(
          in[0].dims().size(),
          in[i].dims().size(),
          "All inputs of Concat should have same dims except "
          "canonical_axis dim that is equal to ",
          canonical_axis,
          ". Got different sizes for inputs 0 and ",
          i);
      for (int j = 0; j < in[0].dims().size(); ++j) {
        if (j == canonical_axis) {
          continue;
        }
        CAFFE_ENFORCE_EQ(
            in[0].dims(j),
            in[i].dims(j),
            "All inputs of Concat should have same dims except "
            "canonical_axis dim that is equal to ",
            canonical_axis,
            ". Got different dims for inputs 0 and ",
            i,
            ". At dim: ",
            j);
      }
    }

    for (int i = 1; i < in.size(); ++i) {
      out_shape[canonical_axis] += in[i].dims(canonical_axis);
    }
  }
  if (def.output_size() == 1) {
    return vector<TensorShape>{CreateTensorShape(out_shape, in[0].data_type())};
  }
  return vector<TensorShape>{
      CreateTensorShape(out_shape, in[0].data_type()),
      CreateTensorShape(split_shape, TensorProto::INT32)};
}

REGISTER_CPU_OPERATOR(Concat, ConcatOp<CPUContext>);
OPERATOR_SCHEMA(Concat)
    .NumInputs(1, INT_MAX)
    .NumOutputs(2)
    .Arg("axis", "*(type: int; default: -1)* Axis to concatenate on.")
    .Arg(
        "order",
        "*(type: string; default='NCHW')* Order of blob dimensions. Concats on the C dimension.")
    .Arg(
        "add_axis",
        "*(type: int)* Pass a non-zero integer to add the axis specified in `axis` to all input tensors.")
    .TensorInferenceFunction(
        OpSchema::NeedsAllInputShapes(TensorInferenceForConcat))
    .CostInferenceFunction(CostInferenceForConcat)
    .DeviceInferenceFunction(concatOpDevInfer)
    .SetDoc(R"DOC(
Concatenate a list of tensors into a single tensor. Similar functionality to
NumPy's [concatenate](https://docs.scipy.org/doc/numpy/reference/generated/numpy.concatenate.html)
function. The `axis` argument specifies the axis along which the arrays will be concatenated.
When the `add_axis` argument is set to non-zero (default=0), the axis specified in `axis` is
added to all input tensors.

Github Links:

- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/concat_split_op.cc
- https://github.com/pytorch/pytorch/blob/master/caffe2/operators/concat_split_op.h


<details>

<summary> <b>Example</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Concat",
    ["X1", "X2"],
    ["Y", "split_info"],
    axis=0
)

workspace.FeedBlob("X1", np.array([[1,2],[3,4]]))
workspace.FeedBlob("X2", np.array([[5,6]]))
print("X1:", workspace.FetchBlob("X1"))
print("X2:", workspace.FetchBlob("X2"))
workspace.RunOperatorOnce(op)
print("Y:", workspace.FetchBlob("Y"))
print("split_info:", workspace.FetchBlob("split_info"))

```

**Result**

```

X1: [[1 2]
 [3 4]]
X2: [[5 6]]
Y: [[1 2]
 [3 4]
 [5 6]]
split_info: [2 1]

```

</details>

<details>

<summary> <b>Example 2</b> </summary>

**Code**

```

workspace.ResetWorkspace()

op = core.CreateOperator(
    "Concat",
    ["X1", "X2"],
    ["Y", "split_info"],
    add_axis=1,
    axis=3
)

workspace.FeedBlob("X1", np.random.randint(10, size=(1, 1, 5, 5)))  # NCHW
workspace.FeedBlob("X2", np.random.randint(10, size=(1, 1, 5, 5)))  # NCHW
print("X1:", workspace.FetchBlob("X1"))
print("X2:", workspace.FetchBlob("X2"))
workspace.RunOperatorOnce(op)
print("Y:", workspace.FetchBlob("Y"))
print("split_info:", workspace.FetchBlob("split_info"))

```

**Result**

```

X1: [[[[1 8 3 9 0]
 [6 4 6 5 6]
 [3 9 1 9 9]
 [5 1 0 7 7]
 [9 4 0 0 9]]]]
X2: [[[[7 0 2 6 1]
 [3 9 4 0 3]
 [5 3 8 9 4]
 [3 4 2 1 0]
 [0 8 8 8 1]]]]
Y: [[[[[1 8 3 9 0]
 [7 0 2 6 1]]

 [[6 4 6 5 6]
 [3 9 4 0 3]]

 [[3 9 1 9 9]
 [5 3 8 9 4]]

 [[5 1 0 7 7]
 [3 4 2 1 0]]

 [[9 4 0 0 9]
 [0 8 8 8 1]]]]]
split_info: [1 1]

```

</details>

367  )DOC")
368  .Input(0, "X1, X2, ...", "*(type: Tensor`<float>`)* List of input tensors.")
369  .Output(
370  0,
371  "concat_result",
372  "*(type: Tensor`<float>`)* Concatenated tensor.")
373  .Output(
374  1,
375  "split_info",
376  "*(type: Tensor`<int>`)* The dimensions of the inputs.")
377  .InheritOnnxSchema();
378 
379 // Backward compatibility names.
380 REGISTER_CPU_OPERATOR(DepthSplit, SplitOp<CPUContext>);
381 REGISTER_CPU_OPERATOR(DepthConcat, ConcatOp<CPUContext>);
382 OPERATOR_SCHEMA(DepthSplit)
383  .NumInputs(1, 2)
384  .NumOutputs(1, INT_MAX)
385  .SetDoc("Backward compatible operator name for Split.");
386 OPERATOR_SCHEMA(DepthConcat)
387  .NumInputs(1, INT_MAX)
388  .NumOutputs(2)
389  .SetDoc("Backward compatible operator name for Concat.");
390 
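// The gradient of Split is a Concat of the gradients of the non-empty outputs
// back into the gradient of the input, with the recorded dims as a side output.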
class GetSplitGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    vector<string> output_grads;
    for (int i = 0; i < def_.output_size(); ++i) {
      if (!GradOut(i).IsEmpty()) {
        output_grads.push_back(GO(i));
      }
    }
    if (output_grads.empty()) {
      return {};
    }
    return SingleGradientDef(
        "Concat",
        "",
        output_grads,
        vector<string>{GI(0), "_" + GI(0) + "_dims"});
  }
};
REGISTER_GRADIENT(Split, GetSplitGradient);
REGISTER_GRADIENT(DepthSplit, GetSplitGradient);
REGISTER_GRADIENT(SplitByLengths, GetSplitGradient);

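// Conversely, the gradient of Concat is a Split of the output gradient, fed
// with the saved split_info (Concat's second output) as the split-sizes input.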
class GetConcatGradient : public GradientMakerBase {
  using GradientMakerBase::GradientMakerBase;
  vector<OperatorDef> GetGradientDefs() override {
    if (GradOut(0).IsEmpty()) {
      return {};
    }
    vector<string> grads;
    for (int i = 0; i < def_.input_size(); ++i) {
      grads.push_back(GI(i));
    }
    return SingleGradientDef("Split", "", vector<string>{GO(0), O(1)}, grads);
  }
};
REGISTER_GRADIENT(Concat, GetConcatGradient);
REGISTER_GRADIENT(DepthConcat, GetConcatGradient);
} // namespace caffe2