Caffe2 - C++ API
A deep learning, cross-platform ML framework
batch_gather_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/batch_gather.h"
3 #include "caffe2/utils/math.h"
4 #include "caffe2/core/tensor.h"
5 
7 using caffe2::Tensor;
8 using std::vector;
9 
10 namespace caffe2 {
11 namespace {
12 
13 template <class TInd>
14 void batch_gather_op_cpu_impl(
15  const at::Tensor& data_,
16  const at::Tensor& indices_,
17  const at::Tensor& output_) {
18  Tensor data{C10Tensor(data_)};
19  Tensor indices{C10Tensor(indices_)};
20  Tensor output{C10Tensor(output_)};
21  CPUContext context;
22 
23  CAFFE_ENFORCE_GE(data.dim(), 2, "DATA should be at least 2-D");
24 
25  vector<int64_t> shape;
26  shape.push_back(data.size(0));
27  shape.insert(shape.end(), indices.sizes().begin(), indices.sizes().end());
28  shape.insert(shape.end(), data.sizes().begin() + 2, data.sizes().end());
29  output.Resize(shape);
30 
31  auto block_size = data.size_from_dim(2);
32  auto block_bytesize = block_size * data.dtype().itemsize();
33  auto N = indices.numel();
34  auto data_batch_bytesize = data.size_from_dim(1) * data.dtype().itemsize();
35  auto gathered_batch_bytesize =
36  N * data.size_from_dim(2) * data.dtype().itemsize();
37  const TInd* idxs = indices.template data<TInd>();
38  auto src_base = static_cast<const char*>(data.raw_data());
39  auto out = static_cast<char*>(output.raw_mutable_data(data.dtype()));
40 
41  for (auto batch = 0; batch < data.size(0); ++batch) {
42  for (auto i = 0; i < N; ++i) {
43  auto idx = idxs[i];
44  CAFFE_ENFORCE(
45  0 <= idx && idx < data.size(1),
46  "INDICES element is out of DATA bounds, id=",
47  idx,
48  " data_dim=",
49  data.size(1));
50  auto src = src_base + idx * block_bytesize + batch * data_batch_bytesize;
51  auto dst = out + i * block_bytesize + batch * gathered_batch_bytesize;
52  context.CopyItemsSameDevice(data.dtype(), block_size, src, dst);
53  }
54  }
55 }
56 
57 void batch_gather_op_cpu(const at::Tensor& data,
58  const at::Tensor& indices,
59  const at::Tensor& output) {
60  switch (data.scalar_type()) {
61  case ScalarType::Int: return batch_gather_op_cpu_impl<int>(data, indices, output);
62  case ScalarType::Long: return batch_gather_op_cpu_impl<int64_t>(data, indices, output);
63  default: throw std::runtime_error(string() + "Unsupported dtype: " + toString(data.scalar_type()));
64  }
65 }
66 } // namespace
67 } // namespace caffe2
68 
// Registers caffe2::batch_gather_op_cpu as the kernel for the
// caffe2::ops::BatchGather operator, with CPUTensorId() as the dispatch key.
// The kernel function is declared inside an anonymous namespace, so this
// registration has to stay in the same translation unit as its definition.
69 namespace c10 {
70 C10_REGISTER_KERNEL(caffe2::ops::BatchGather)
71  .kernel<decltype(caffe2::batch_gather_op_cpu), &caffe2::batch_gather_op_cpu>()
72  .dispatchKey(CPUTensorId());
73 } // namespace c10
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
Definition: tensor.h:25
Virtual interface for the Context class in Caffe2.
Definition: context_base.h:32
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7