Caffe2 - C++ API
A deep learning, cross platform ML framework
concat_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/concat.h"
3 #include "caffe2/utils/math.h"
4 #include "caffe2/core/tensor.h"
5 
8 using caffe2::Tensor;
10 using std::vector;
11 
12 namespace caffe2 {
13 namespace {
14 template <class DataType, class Context>
15 void concat_op_cpu_impl(
16  ArrayRef<at::Tensor> inputs,
17  const at::Tensor& output_,
18  const at::Tensor& split_,
19  int64_t axis,
20  int64_t add_axis) {
21  Tensor output{C10Tensor(output_)};
22  Tensor split{C10Tensor(split_)};
23  CPUContext context;
24 
25  split.Resize(vector<int64_t>(1, inputs.size()));
26  int* axis_data = split.template mutable_data<int>();
27  int adj_size = Tensor(inputs[0]).dim() + (add_axis ? 1 : 0);
28  int canonical_axis = caffe2::canonical_axis_index_(axis, adj_size);
29  CAFFE_ENFORCE_LT(canonical_axis, adj_size, "Axis not in input ndim range.");
30  for (size_t i = 1; i < inputs.size(); ++i) {
31  CAFFE_ENFORCE(
32  Tensor(inputs[i]).dtype() == Tensor(inputs[0]).dtype(),
33  "All inputs must have the same type, expected: ",
34  Tensor(inputs[0]).dtype().name(),
35  " but got: ",
36  Tensor(inputs[i]).dtype().name(),
37  " for input: ",
38  i);
39  }
40 
41  int before = 1, after = 1;
42  vector<int64_t> output_dims(Tensor(inputs[0]).sizes().vec());
43  for (int i = 0; i < Tensor(inputs[0]).dim(); ++i) {
44  if (i == canonical_axis && !add_axis) {
45  continue;
46  }
47  int dim = Tensor(inputs[0]).dim32(i);
48  if (i < canonical_axis) {
49  before *= dim;
50  } else { // i > canonical_axis || i == canonical_axis && add_axis
51  after *= dim;
52  }
53  // check the input dims are compatible.
54  for (size_t j = 1; j < inputs.size(); ++j) {
55  int dim_j = Tensor(inputs[j]).dim32(i);
56  CAFFE_ENFORCE(
57  dim == dim_j,
58  "Expect dimension = ",
59  dim,
60  " got ",
61  dim_j,
62  " at axis = ",
63  i,
64  " for input: ",
65  j,
66  ". The input tensors can only have different dimensions "
67  "when arg 'add_axis' = 0 and along the axis = ",
68  canonical_axis,
69  " <",
70  Tensor(inputs[0]).sizes(),
71  "> vs <",
72  Tensor(inputs[j]).sizes(),
73  ">.");
74  }
75  }
76 
77  int output_channels = 0;
78  for (size_t i = 0; i < inputs.size(); ++i) {
79  axis_data[i] = add_axis ? 1 : Tensor(inputs[i]).dim32(canonical_axis);
80  output_channels += axis_data[i];
81  }
82  if (add_axis) {
83  output_dims.insert(output_dims.begin() + canonical_axis, output_channels);
84  } else {
85  output_dims[canonical_axis] = output_channels;
86  }
87  output.Resize(output_dims);
88  size_t output_offset = 0;
89  for (size_t i = 0; i < inputs.size(); ++i) {
90  Tensor input(inputs[i]);
91  auto axis_dim = add_axis ? 1 : input.dim32(canonical_axis);
92  caffe2::math::CopyMatrix<Context>(
93  input.itemsize(),
94  before,
95  axis_dim * after,
96  input.raw_data(),
97  axis_dim * after,
98  static_cast<char*>(output.raw_mutable_data(Tensor(inputs[0]).dtype())) +
99  output_offset,
100  output_channels * after,
101  static_cast<Context*>(&context),
102  Tensor(inputs[0]).dtype().copy());
103  output_offset += axis_dim * after * input.itemsize();
104  }
105 }
106 } // namespace
107 } // namespace caffe2
108 
109 namespace c10 {
110 C10_REGISTER_KERNEL(caffe2::ops::Concat)
111  .kernel<decltype(caffe2::concat_op_cpu_impl<float, CPUContext>), &caffe2::concat_op_cpu_impl<float, CPUContext>>()
112  .dispatchKey(CPUTensorId());
113 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Definition: tensor.h:25
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
const TypeMeta & dtype() const
Returns the TypeMeta object associated with the current data type.
Definition: tensor.h:557
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
constexpr Copy * copy() const noexcept
Returns the typed copy function pointer for individual items.
Definition: typeid.h:380
int dim() const
Returns the number of dimensions of the data.
Definition: tensor.h:461
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.
Definition: tensor.h:576