1 #include <ATen/core/dispatch/KernelRegistration.h> 2 #include "caffe2/operators/experimental/c10/schemas/sparse_lengths_sum.h" 3 #include "caffe2/perfkernels/embedding_lookup.h" 4 #include "caffe2/utils/math.h" 5 #include "caffe2/core/tensor.h" 12 template <
typename InputType,
typename IndexType>
// CPU implementation of SparseLengthsSum for one (InputType, IndexType) pair.
//
// Wraps the incoming tensors as caffe2 Tensors, checks that INDICES and
// LENGTHS are 1-D, gathers raw data pointers and hands them to
// caffe2::EmbeddingLookup (from caffe2/perfkernels/embedding_lookup.h),
// which presumably performs the per-segment sum — TODO confirm.
//
// Template parameters:
//   InputType - element type of the data tensor (read via data<InputType>()).
//   IndexType - element type of the indices tensor (read via data<IndexType>()).
// The lengths tensor is always read as int.
//
// NOTE(review): the parameter list, the declaration of the output element
// type T, and the EmbeddingLookup argument list are not visible in this
// excerpt; the `_`-suffixed names below are presumably the parameters.
13 void sparse_lengths_sum_op_cpu_impl_(
// Re-wrap each incoming tensor as a caffe2 Tensor (via C10Tensor) so the
// caffe2 Tensor API (dim, size, size_from_dim, data<T>) can be used below.
18 Tensor dataInput{C10Tensor(dataInput_)};
19 Tensor indicesInput{C10Tensor(indicesInput_)};
20 Tensor lengthsInput{C10Tensor(lengthsInput_)};
21 Tensor output{C10Tensor(output_)};
// Compile-time switches: no mean normalization and no positional weights.
// USE_MEAN is presumably passed in the elided EmbeddingLookup arguments.
24 constexpr
bool USE_MEAN =
false;
25 constexpr
bool USE_POSITIONAL_WEIGHT =
false;
// INDICES and LENGTHS must both be vectors (rank 1).
27 CAFFE_ENFORCE_EQ(1, indicesInput.dim(),
"INDICES must be a vector");
28 CAFFE_ENFORCE_EQ(1, lengthsInput.dim(),
"LENGTHS must be a vector");
// N: first dimension of the data tensor.
29 const int64_t N = dataInput.size(0);
// D: product of the data tensor's dimensions from 1 onward (elements per row).
// NOTE(review): size_from_dim presumably returns int64_t; storing it in an
// int may truncate for very wide rows — confirm what EmbeddingLookup accepts.
30 const int D = dataInput.size_from_dim(1);
// M: number of LENGTHS entries.
31 const int64_t
M = lengthsInput.size(0);
// Total number of indices.
32 const int64_t indices_size = indicesInput.numel();
// Copy of the data tensor's sizes; its use (presumably resizing `output`)
// is elided from this excerpt.
34 auto shape = dataInput.sizes().vec();
37 T* out_data = output.template mutable_data<T>();
39 const InputType* in_data = dataInput.template data<InputType>();
40 const IndexType* indices = indicesInput.template data<IndexType>();
41 const int* lengths = lengthsInput.template data<int>();
// No per-index weights are supplied.
42 const T* in_weight =
nullptr;
// Delegate the work to the perfkernel; the argument list is elided here.
45 caffe2::EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
// Dispatches on the dtype of `dataInput` for a fixed IndexType:
//   Float -> sparse_lengths_sum_op_cpu_impl_<float, IndexType>
//   Half  -> sparse_lengths_sum_op_cpu_impl_<at::Half, IndexType>
// Any other data dtype throws std::runtime_error naming that dtype.
// NOTE(review): the parameter list and closing braces are elided in this excerpt.
59 template<
typename IndexType>
60 void sparse_lengths_sum_op_cpu_impl(
65 switch (dataInput.scalar_type()) {
66 case ScalarType::Float:
return sparse_lengths_sum_op_cpu_impl_<float, IndexType>(dataInput, indicesInput, lengthsInput, output);
67 case ScalarType::Half:
return sparse_lengths_sum_op_cpu_impl_<at::Half, IndexType>(dataInput, indicesInput, lengthsInput, output);
68 default:
throw std::runtime_error(
string() +
"Unsupported dtype for input data " + toString(dataInput.scalar_type()));
72 void sparse_lengths_sum_op_cpu(
77 switch (indicesInput.scalar_type()) {
78 case ScalarType::Int:
return sparse_lengths_sum_op_cpu_impl<int>(dataInput, indicesInput, lengthsInput, output);
79 case ScalarType::Long:
return sparse_lengths_sum_op_cpu_impl<int64_t>(dataInput, indicesInput, lengthsInput, output);
80 default:
throw std::runtime_error(
string() +
"Unsupported dtype for input indices " + toString(dataInput.scalar_type()));
// Register caffe2::sparse_lengths_sum_op_cpu as the kernel for the
// caffe2::ops::SparseLengthsSum operator under the CPUTensorId dispatch key.
// Per the C10_REGISTER_KERNEL documentation, this registration must appear in
// exactly one .cpp file.
88 C10_REGISTER_KERNEL(caffe2::ops::SparseLengthsSum)
89 .kernel<decltype(caffe2::sparse_lengths_sum_op_cpu), &caffe2::sparse_lengths_sum_op_cpu>()
90 .dispatchKey(CPUTensorId());
Tensor class holds a shared pointer to the implementation TensorImpl, redirects API calls to TensorIm...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
To register your own kernel for an operator, do in one (!) cpp file: C10_REGISTER_KERNEL(OperatorHand...