2 #include "caffe2/core/context.h" 3 #include "caffe2/core/operator.h" 4 #include "caffe2/perfkernels/embedding_lookup.h" 14 bool USE_POSITIONAL_WEIGHT = 0
21 template <
class... Args>
25 !(USE_WEIGHT & USE_MEAN),
"Cannot both specify weight and mean.");
28 ~CPUSparseLengthsReductionOp() {}
33 bool RunOnDevice()
override {
37 template <
typename InputType>
38 bool DoRunWithType() {
40 this,
Input(INDICES));
43 template <
typename InputType,
typename IndexType>
44 bool DoRunWithType2() {
45 auto& dataInput =
Input(DATA);
46 auto& indicesInput =
Input(INDICES);
47 auto& lengthsInput =
Input(LENGTHS);
49 CAFFE_ENFORCE_EQ(1, indicesInput.dim(),
"INDICES must be a vector");
50 CAFFE_ENFORCE_EQ(1, lengthsInput.dim(),
"LENGTHS must be a vector");
51 const int64_t N = dataInput.size(0);
52 const int D = dataInput.size_from_dim(1);
53 const int64_t
M = lengthsInput.size(0);
54 const int64_t indices_size = indicesInput.numel();
56 auto shape = dataInput.sizes().vec();
58 auto* output = Output(0, shape, at::dtype<T>());
59 T* out_data = output->template mutable_data<T>();
61 const InputType* in_data = dataInput.template data<InputType>();
62 const IndexType* indices = indicesInput.template data<IndexType>();
63 const int* lengths = lengthsInput.template data<int>();
64 const T* in_weight =
nullptr;
68 auto& weightInput =
Input(WEIGHT);
69 CAFFE_ENFORCE_EQ(1, weightInput.dim(),
"WEIGHT must be a vector");
70 if (!USE_POSITIONAL_WEIGHT) {
74 "Weight should have the same length as indices.");
76 in_weight = weightInput.template data<T>();
80 EmbeddingLookup<IndexType, InputType, T, USE_POSITIONAL_WEIGHT>(
98 INDICES = 1 + USE_WEIGHT,
100 LENGTHS = 2 + USE_WEIGHT,
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
const Tensor & Input(int idx, DeviceType type=CPUContext::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...