Caffe2 - C++ API
A deep learning, cross platform ML framework
sparse_lengths_sum_cpu.cc
1 #include <ATen/core/dispatch/KernelRegistration.h>
2 #include "caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h"
3 #include "caffe2/perfkernels/embedding_lookup.h"
4 #include "caffe2/utils/math.h"
5 #include "caffe2/core/tensor.h"
6 
7 using caffe2::Tensor;
8 
9 namespace caffe2 {
10 namespace {
11 
12 template <typename InputType, typename IndexType>
13 void sparse_lengths_sum_op_cpu_impl_(
14  const at::Tensor& dataInput_,
15  const at::Tensor& indicesInput_,
16  const at::Tensor& lengthsInput_,
17  const at::Tensor& output_) {
18  Tensor dataInput{C10Tensor(dataInput_)};
19  Tensor indicesInput{C10Tensor(indicesInput_)};
20  Tensor lengthsInput{C10Tensor(lengthsInput_)};
21  Tensor output{C10Tensor(output_)};
22 
23  using T = float;
24  constexpr bool USE_MEAN = false;
25  constexpr bool USE_POSITIONAL_WEIGHT = false;
26 
27  CAFFE_ENFORCE_EQ(1, indicesInput.dim(), "INDICES must be a vector");
28  CAFFE_ENFORCE_EQ(1, lengthsInput.dim(), "LENGTHS must be a vector");
29  const int64_t N = dataInput.size(0);
30  const int D = dataInput.size_from_dim(1);
31  const int64_t M = lengthsInput.size(0);
32  const int64_t indices_size = indicesInput.numel();
33 
34  auto shape = dataInput.sizes().vec();
35  shape[0] = M;
36  output.Resize(shape);
37  T* out_data = output.template mutable_data<T>();
38 
39  const InputType* in_data = dataInput.template data<InputType>();
40  const IndexType* indices = indicesInput.template data<IndexType>();
41  const int* lengths = lengthsInput.template data<int>();
42  const T* in_weight = nullptr;
43 
44  // delegate work to perfkernel that branches based on architecture
45  caffe2::EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
46  D,
47  M,
48  indices_size,
49  N,
50  in_data,
51  indices,
52  lengths,
53  in_weight,
54  nullptr, // scale_bias field is only used in SparseLengths8BitsRowwiseOp
55  USE_MEAN,
56  out_data);
57 }
58 
59 template<typename IndexType>
60 void sparse_lengths_sum_op_cpu_impl(
61  const at::Tensor& dataInput,
62  const at::Tensor& indicesInput,
63  const at::Tensor& lengthsInput,
64  const at::Tensor& output) {
65  switch (dataInput.scalar_type()) {
66  case ScalarType::Float: return sparse_lengths_sum_op_cpu_impl_<float, IndexType>(dataInput, indicesInput, lengthsInput, output);
67  case ScalarType::Half: return sparse_lengths_sum_op_cpu_impl_<at::Half, IndexType>(dataInput, indicesInput, lengthsInput, output);
68  default: throw std::runtime_error(string() + "Unsupported dtype for input data " + toString(dataInput.scalar_type()));
69  }
70 }
71 
72 void sparse_lengths_sum_op_cpu(
73  const at::Tensor& dataInput,
74  const at::Tensor& indicesInput,
75  const at::Tensor& lengthsInput,
76  const at::Tensor& output) {
77  switch (indicesInput.scalar_type()) {
78  case ScalarType::Int: return sparse_lengths_sum_op_cpu_impl<int>(dataInput, indicesInput, lengthsInput, output);
79  case ScalarType::Long: return sparse_lengths_sum_op_cpu_impl<int64_t>(dataInput, indicesInput, lengthsInput, output);
80  default: throw std::runtime_error(string() + "Unsupported dtype for input indices " + toString(dataInput.scalar_type()));
81  }
82 }
83 
84 } // namespace
85 } // namespace caffe2
86 
87 namespace c10 {
// Register caffe2::sparse_lengths_sum_op_cpu as the kernel for the
// caffe2::ops::SparseLengthsSum operator under the CPUTensorId dispatch key.
// The operator handle presumably comes from the included
// schemas/sparse_lengths_sum.h header — confirm there.
88 C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
89  .kernel<decltype(caffe2::sparse_lengths_sum_op_cpu), &caffe2::sparse_lengths_sum_op_cpu>()
90  .dispatchKey(CPUTensorId());
91 } // namespace c10
Definition: any.cpp:108
Tensor class holds a shared pointer to the implementation TensorImpl and redirects API calls to TensorImpl.
Definition: tensor.h:25
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...
Definition: alias_info.h:7
Definition: static.cpp:70