// NOTE(review): this is a numbered listing of the kernel source. The numbers
// embedded in the text appear to be the original file's line numbers; gaps in
// them (e.g. 4 -> 14, 16 -> 21, 36 -> 41) mean those lines are not shown here.
1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/experimental/c10/schemas/concat.h" 3 #include "caffe2/utils/math.h" 4 #include "caffe2/core/tensor.h" 14 template <
class DataType,
class Context>
// Concat kernel body: copies every tensor in `inputs` into `output_` along a
// single axis and writes each input's extent along that axis into `split_`.
//
// Only the `inputs` parameter is visible in this listing. `output_`, `split_`,
// `axis` and `add_axis` are used below, but their declarations are on lines
// that are not shown -- confirm their types against the full source.
//
// NOTE(review): `inputs[0]` is read several times with no visible check that
// `inputs` is non-empty; verify whether the caller or a hidden line guards it.
15 void concat_op_cpu_impl(
16 ArrayRef<at::Tensor> inputs,
// Wrap the incoming tensors as caffe2 Tensor objects for the calls below.
21 Tensor output{C10Tensor(output_)};
22 Tensor split{C10Tensor(split_)};
// `split` becomes a 1-D tensor with one int entry per input.
25 split.Resize(vector<int64_t>(1, inputs.size()));
26 int* axis_data = split.template mutable_data<int>();
// Rank used for axis validation: the first input's rank, plus one when
// `add_axis` is set.
27 int adj_size =
Tensor(inputs[0]).
dim() + (add_axis ? 1 : 0);
// presumably canonical_axis_index_ maps a negative `axis` into [0, adj_size)
// -- TODO confirm against its definition.
28 int canonical_axis = caffe2::canonical_axis_index_(axis, adj_size);
29 CAFFE_ENFORCE_LT(canonical_axis, adj_size,
"Axis not in input ndim range.");
// Walk inputs 1..N-1. The visible message fragments belong to a check that all
// inputs share input 0's dtype; the enforcing call itself is not shown here.
30 for (
size_t i = 1; i < inputs.size(); ++i) {
33 "All inputs must have the same type, expected: ",
34 Tensor(inputs[0]).dtype().name(),
36 Tensor(inputs[i]).dtype().name(),
// `before` / `after` start at 1; the statements that update them are not shown.
// NOTE(review): presumably they accumulate the products of the dimensions
// before / after the concat axis -- confirm.
41 int before = 1, after = 1;
// Output shape starts as a copy of the first input's shape.
42 vector<int64_t> output_dims(
Tensor(inputs[0]).sizes().vec());
// Iterate over the first input's dimensions. The bodies of the two `if`s below
// are not shown in this listing.
43 for (
int i = 0; i <
Tensor(inputs[0]).
dim(); ++i) {
44 if (i == canonical_axis && !add_axis) {
48 if (i < canonical_axis) {
// Inner loop over the remaining inputs. The message fragments suggest a
// per-dimension shape comparison against input 0 (check itself not shown).
54 for (
size_t j = 1; j < inputs.size(); ++j) {
58 "Expect dimension = ",
66 ". The input tensors can only have different dimensions " 67 "when arg 'add_axis' = 0 and along the axis = ",
// Record each input's contribution along the concat axis in `split` (1 when
// `add_axis` is set, otherwise that input's extent on `canonical_axis`) and
// sum the contributions into `output_channels`.
77 int output_channels = 0;
78 for (
size_t i = 0; i < inputs.size(); ++i) {
79 axis_data[i] = add_axis ? 1 :
Tensor(inputs[i]).
dim32(canonical_axis);
80 output_channels += axis_data[i];
// Two alternative updates of `output_dims` are visible: inserting a new
// dimension at `canonical_axis` (83) or overwriting the existing one (85).
// The condition selecting between them is not shown -- presumably `add_axis`;
// confirm against the full source.
83 output_dims.insert(output_dims.begin() + canonical_axis, output_channels);
85 output_dims[canonical_axis] = output_channels;
87 output.Resize(output_dims);
// `output_offset` is a byte offset into the output buffer: it is advanced by
// `axis_dim * after * input.itemsize()` after each input is copied.
88 size_t output_offset = 0;
89 for (
size_t i = 0; i < inputs.size(); ++i) {
// `input` is declared on a line that is not shown (presumably
// Tensor(inputs[i]) -- confirm).
91 auto axis_dim = add_axis ? 1 : input.dim32(canonical_axis);
// Only some CopyMatrix arguments are visible: the destination base pointer
// (output data, allocated with input 0's dtype, viewed as char*), the value
// `output_channels * after`, and the context pointer.
92 caffe2::math::CopyMatrix<Context>(
98 static_cast<char*
>(output.raw_mutable_data(
Tensor(inputs[0]).dtype())) +
100 output_channels * after,
101 static_cast<Context*>(&context),
103 output_offset += axis_dim * after * input.itemsize();
// Registers the <float, CPUContext> instantiation of concat_op_cpu_impl as the
// kernel for the caffe2::ops::Concat operator, with CPUTensorId() as the
// dispatch key.
110 C10_REGISTER_KERNEL(caffe2::ops::Concat)
111 .kernel<decltype(caffe2::concat_op_cpu_impl<float, CPUContext>), &caffe2::concat_op_cpu_impl<float, CPUContext>>()
112 .dispatchKey(CPUTensorId());
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
const TypeMeta & dtype() const
Returns the TypeMeta object associated with the current data type.
Virtual interface for the Context class in Caffe2.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
int dim() const
Returns the number of dimensions of the data.
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
int dim32(const int i) const
Returns the i-th dimension of the tensor in int.