1 #ifndef CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_ 2 #define CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_ 4 #include "caffe2/core/context.h" 5 #include "caffe2/core/logging.h" 6 #include "caffe2/core/operator.h" 7 #include "caffe2/operators/fused_rowwise_8bit_conversion_ops.h" 8 #include "caffe2/operators/reducer_functors.h" 9 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" 10 #include "caffe2/utils/math.h" 14 template <
class Context,
bool with_weights = 0,
bool is_mean = 0>
// Compile-time guard: the two boolean template parameters are mutually
// exclusive, so instantiating with both with_weights and is_mean set is
// rejected with the message below.
// NOTE(review): the opening of this assertion (listing line 17) is not visible
// in this listing.
18 !(with_weights && is_mean),
19 "Cannot have with_weights and is_mean at the same time");
21 USE_OPERATOR_CONTEXT_FUNCTIONS;
// Operator entry point (overrides the base-class RunOnDevice).
// The visible fragment forwards `this` and the INDICES input to a call whose
// callee (listing line 25) is missing from this listing.
// NOTE(review): presumably this dispatches on the element type of INDICES to
// DoRunWithType<IndexType>() — confirm against the complete file.
24 bool RunOnDevice()
override {
26 this,
Input(INDICES));
// Typed body of the operator, templated on the element type of INDICES.
//
// Visible steps: read the DATA, INDICES and LENGTHS inputs, check that INDICES
// and LENGTHS are 1-D, optionally read WEIGHTS as float, check that DATA has
// more than 8 columns, allocate a float output of shape
// {lengths.size(0), data.size(1) - 8}, and pass raw pointers to a kernel call.
//
// NOTE(review): this listing has gaps (the embedded listing numbers jump
// 38->40, 46->49, 53->60, 62->65), so the guard around the WEIGHTS section,
// the callee of the final call and the return statement are not visible here.
29 template <
typename IndexType>
30 bool DoRunWithType() {
31 const auto& data =
Input(DATA);
32 const auto& indices =
Input(INDICES);
33 const auto& lengths =
Input(LENGTHS);
// INDICES and LENGTHS must both be one-dimensional.
35 CAFFE_ENFORCE_EQ(indices.dim(), 1,
"INDICES must be a vector");
36 CAFFE_ENFORCE_EQ(lengths.dim(), 1,
"LENGTHS must be a vector");
// Stays null unless the WEIGHTS input is read below.
38 const float* weights =
nullptr;
// NOTE(review): listing line 39 is missing; presumably the WEIGHTS handling
// below only runs when with_weights is true — confirm.
40 const auto& weights_input =
Input(WEIGHTS);
41 CAFFE_ENFORCE_EQ(weights_input.dim(), 1,
"WEIGHTS must be a vector");
// Fragment of a second check: its macro name and the value compared against
// weights_input.numel() (listing lines 42 and 44) are missing. Per the
// message, WEIGHTS must have as many elements as INDICES.
43 weights_input.numel(),
45 "WEIGHTS should have the same length as INDICES.");
46 weights = weights_input.template data<float>();
// Each DATA row must be wider than 8 columns, because 8 columns are removed
// when computing the output width below.
// NOTE(review): presumably those 8 bytes per row hold the quantization
// parameters of the fused 8-bit rowwise format — confirm against
// fused_rowwise_8bit_conversion_ops.h.
49 CAFFE_ENFORCE_GT(data.size(1), 8,
"DATA must have more than 8 columns");
// Output: one row per entry of LENGTHS, data.size(1) - 8 float columns.
52 const std::vector<int64_t> shape = {lengths.size(0), data.size(1) - 8};
53 auto* output = Output(0, shape, at::dtype<float>());
// Trailing arguments of a call whose callee and leading arguments (listing
// lines 54-59) are missing: uint8_t data, IndexType indices, int lengths and
// the writable float output buffer.
// NOTE(review): presumably the callee is Fused8BitRowwiseEmbeddingLookup from
// the included perfkernels header; listing lines 63-64 would then be the
// weights pointer and the normalize_by_lengths flag — confirm.
60 data.template data<uint8_t>(),
61 indices.template data<IndexType>(),
62 lengths.template data<int>(),
65 output->template mutable_data<float>());
// Input slot numbers for INDICES and LENGTHS. with_weights is a bool template
// parameter, so it adds 0 or 1: both slots move up by one when it is true.
// NOTE(review): the enumeration's opening and the entries before these
// (listing lines 70-72, presumably including DATA and WEIGHTS, which are used
// above) are missing from this listing.
73 INDICES = 1 + with_weights,
74 LENGTHS = 2 + with_weights,
80 #endif // CAFFE2_OPERATORS_LENGTHS_REDUCER_FUSED_8BIT_ROWWISE_OPS_H_
const Tensor & Input(int idx, DeviceType type=Context::GetDeviceType())
Retrieve a non-owning reference to the input at position 'idx' for this operator. ...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
void Fused8BitRowwiseEmbeddingLookup(const std::int64_t block_size, const std::int64_t output_size, const std::int64_t index_size, const std::int64_t data_size, const InType *input, const IndexType *indices, const int *lengths, const float *weights, bool normalize_by_lengths, OutType *out)
Embedding lookup with reduction.