1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/experimental/c10/schemas/batch_gather.h" 3 #include "caffe2/utils/math.h" 4 #include "caffe2/core/tensor.h" 14 void batch_gather_op_cpu_impl(
18 Tensor data{C10Tensor(data_)};
19 Tensor indices{C10Tensor(indices_)};
20 Tensor output{C10Tensor(output_)};
23 CAFFE_ENFORCE_GE(data.dim(), 2,
"DATA should be at least 2-D");
25 vector<int64_t> shape;
26 shape.push_back(data.size(0));
27 shape.insert(shape.end(), indices.sizes().begin(), indices.sizes().end());
28 shape.insert(shape.end(), data.sizes().begin() + 2, data.sizes().end());
31 auto block_size = data.size_from_dim(2);
32 auto block_bytesize = block_size * data.dtype().itemsize();
33 auto N = indices.numel();
34 auto data_batch_bytesize = data.size_from_dim(1) * data.dtype().itemsize();
35 auto gathered_batch_bytesize =
36 N * data.size_from_dim(2) * data.dtype().itemsize();
37 const TInd* idxs = indices.template data<TInd>();
38 auto src_base =
static_cast<const char*
>(data.raw_data());
39 auto out =
static_cast<char*
>(output.raw_mutable_data(data.dtype()));
41 for (
auto batch = 0; batch < data.size(0); ++batch) {
42 for (
auto i = 0; i < N; ++i) {
45 0 <= idx && idx < data.size(1),
46 "INDICES element is out of DATA bounds, id=",
50 auto src = src_base + idx * block_bytesize + batch * data_batch_bytesize;
51 auto dst = out + i * block_bytesize + batch * gathered_batch_bytesize;
52 context.CopyItemsSameDevice(data.dtype(), block_size, src, dst);
57 void batch_gather_op_cpu(
const at::Tensor& data,
60 switch (data.scalar_type()) {
61 case ScalarType::Int:
return batch_gather_op_cpu_impl<int>(data, indices, output);
62 case ScalarType::Long:
return batch_gather_op_cpu_impl<int64_t>(data, indices, output);
63 default:
throw std::runtime_error(
string() +
"Unsupported dtype: " + toString(data.scalar_type()));
// Registers caffe2::batch_gather_op_cpu as the kernel for the
// caffe2::ops::BatchGather operator, using CPUTensorId() as the dispatch key.
// The kernel is passed both as a type (decltype) and as a function pointer.
70 C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
71 .kernel<decltype(caffe2::batch_gather_op_cpu), &caffe2::batch_gather_op_cpu>()
72 .dispatchKey(CPUTensorId());
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Virtual interface for the Context class in Caffe2.
A global dictionary that holds information about what Caffe2 modules have been loaded in the current runtime.
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHandle).kernel<decltype(kernel_func), &kernel_func>().dispatchKey(dispatch_key);