1 #include "caffe2/perfkernels/fused_8bit_rowwise_embedding_lookup.h" 3 #include "caffe2/core/types.h" 4 #include "caffe2/perfkernels/common.h" 5 #include "caffe2/perfkernels/typed_axpy.h" 6 #include "caffe2/utils/cpuid.h" 7 #include "caffe2/utils/eigen_utils.h" 8 #include "caffe2/utils/math.h" 20 bool IS_WEIGHT_POSITIONAL =
false>
// Fused8BitRowwiseEmbeddingLookupGenericSlow: portable (generic) path of the
// fused 8-bit rowwise embedding lookup. For each output segment it sums
// scaled rows of `input` selected by `indices`, optionally divided by the
// segment length.
//
// NOTE(review): this listing is fragmentary. Original line numbers are
// embedded in the text, and several original lines are absent: the start of
// the template header, the parameter lines for `input`, `lengths`,
// `weights` and `out`, and a number of body lines. For example, the visible
// lines never increment `current`, never use `bias` or `out_vector`, and do
// not show the bodies of the failed bounds checks. Those statements are
// presumably among the missing lines (TODO confirm against the full file).
// The comments here describe only what the visible lines show.
//
// Row layout, as read by the code below: each row of `input` spans
// `fused_block_size` InType elements. These are `block_size` data values
// followed by 8 bytes read as two floats: scale_bias[0] is the scale and
// scale_bias[1] is the bias.
//
// For each output segment m:
//   - lengths[m] consecutive entries of `indices` are consumed, starting at
//     position `current`.
//   - The segment's output row (block_size OutType values at `out`) is
//     zeroed, then each selected row is accumulated into it via TypedAxpy
//     with factor weight * scale.
//
// Returns: `current == index_size`.
21 static bool Fused8BitRowwiseEmbeddingLookupGenericSlow(
22 const int64_t block_size,
23 const int64_t output_size,
24 const int64_t index_size,
25 const int64_t data_size,
27 const IndexType* indices,
30 bool normalize_by_lengths,
// The scale/bias pair occupies 8 bytes (two floats); this expresses that
// size in units of InType elements so it can be added to block_size.
34 const auto scale_bias_offset = 8 /
sizeof(InType);
35 const int64_t fused_block_size = block_size + scale_bias_offset;
37 for (
int m = 0; m < output_size; ++m) {
// Zero this segment's output row before accumulating into it.
38 memset(out, 0,
sizeof(OutType) * block_size);
39 EigenVectorArrayMap<OutType> out_vector(out, block_size);
// Detect a segment whose indices would run past the end of `indices`.
// The branch body is not visible in this listing.
40 if (current + lengths[m] > index_size) {
43 for (
int i = 0; i < lengths[m]; ++i) {
44 int64_t idx = indices[current];
// Bounds-check the row index against data_size. The branch body is not
// visible in this listing.
45 if (idx < 0 || idx >= data_size) {
// NOTE(review): the call this argument list belongs to is missing from the
// listing. The (address, 0, 1) arguments pointing at the next row suggest a
// __builtin_prefetch of that row (TODO confirm).
49 if (current + 1 < index_size) {
51 input + fused_block_size * indices[current + 1], 0, 1);
// The scale/bias floats sit immediately after the row's block_size data
// values.
55 const float* scale_bias =
reinterpret_cast<const float*
>(
56 input + fused_block_size * indices[current] + block_size);
// Positional weights are indexed by the position i within the segment;
// otherwise by the global position `current` in `indices`.
60 weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
62 const float scale = weight * scale_bias[0];
63 const float bias = weight * scale_bias[1];
// Presumably out += scale * row, converting each InType value to OutType
// (TypedAxpy is defined elsewhere; confirm its semantics in typed_axpy.h).
65 TypedAxpy<InType, OutType>(
66 block_size, scale, input + fused_block_size * indices[current], out);
// Optionally divide the segment result by its length. Empty segments are
// skipped, which avoids a division by zero. The context argument is passed
// as nullptr.
72 if (normalize_by_lengths && lengths[m]) {
74 math::Scale<float, OutType, CPUContext>(
75 block_size, 1.f / lengths[m], out, out,
nullptr);
79 return current == index_size;
// FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType)
// Generates the uint8_t-input, non-positional-weight ("_false") entry points
// for one (IndexType, OutType) pair:
//
//  - Fused8BitRowwiseEmbeddingLookup_<IndexType>_uint8_t_<OutType>_false__base:
//    forwards its arguments to Fused8BitRowwiseEmbeddingLookupGenericSlow,
//    the generic path.
//
//  - A reference to ..._false__avx2_fma: presumably an AVX2+FMA kernel
//    defined in another translation unit (TODO confirm).
//
//  - bool Fused8BitRowwiseEmbeddingLookup_<IndexType>_uint8_t_<OutType>(...):
//    reads the first byte of an int32_t holding 1, which appears to be a
//    host little-endian check. It reports "Fused8BitRowwiseEmbeddingLookup
//    is not supported on this platform" on failure. It then refers to
//    ..._false twice, presumably once through an AVX2 dispatch and once as
//    the fallback. The dispatch lines are not visible here.
//
//  - The specialization
//    Fused8BitRowwiseEmbeddingLookup<IndexType, uint8_t, OutType, false>:
//    calls the bool-returning function above. Its visible tail re-walks
//    lengths/indices to produce a specific error message for each of:
//      * an index position not less than index_size (CAFFE_ENFORCE_LT);
//      * an index outside [0, data_size) ("... is out of bounds: ...");
//      * a lengths sum that does not match the indices tensor size.
//    NOTE(review): the condition guarding this re-walk is not visible.
//    Presumably it runs only when the call returned false (confirm).
//
// The macro is instantiated below for IndexType int32_t and int64_t, both
// with OutType float, and is #undef'd afterwards.
//
// NOTE(review): this listing is fragmentary. Many original lines are absent
// and original line numbers are embedded in the text, so it does not
// compile as shown.
83 #define FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(IndexType, OutType) \ 85 Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base( \ 86 const int64_t block_size, \ 87 const int64_t output_size, \ 88 const int64_t index_size, \ 89 const int64_t data_size, \ 90 const uint8_t* input, \ 91 const IndexType* indices, \ 93 const float* weights, \ 94 bool normalize_by_lengths, \ 96 return Fused8BitRowwiseEmbeddingLookupGenericSlow< \ 109 normalize_by_lengths, \ 113 Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__base) \ 114 Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false__avx2_fma; \ 115 bool Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ 116 const int64_t block_size, \ 117 const int64_t output_size, \ 118 const int64_t index_size, \ 119 const int64_t data_size, \ 120 const uint8_t* input, \ 121 const IndexType* indices, \ 122 const int* lengths, \ 123 const float* weights, \ 124 bool normalize_by_lengths, \ 126 const int32_t one = 1; \ 128 reinterpret_cast<const uint8_t*>(&one)[0], \ 130 "Fused8BitRowwiseEmbeddingLookup is not supported on this platform"); \ 132 Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ 141 normalize_by_lengths, \ 144 Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType##_false, \ 153 normalize_by_lengths, \ 157 void Fused8BitRowwiseEmbeddingLookup<IndexType, uint8_t, OutType, false>( \ 158 const int64_t block_size, \ 159 const int64_t output_size, \ 160 const int64_t index_size, \ 161 const int64_t data_size, \ 162 const uint8_t* input, \ 163 const IndexType* indices, \ 164 const int* lengths, \ 165 const float* weights, \ 166 bool normalize_by_lengths, \ 169 Fused8BitRowwiseEmbeddingLookup_##IndexType##_uint8_t_##OutType( \ 178 normalize_by_lengths, \ 183 int64_t current = 0; \ 184 for (int m = 0; m < output_size; ++m) { \ 185 for (int i = 0; i < lengths[m]; ++i) { \ 186 CAFFE_ENFORCE_LT(current, index_size); \ 187 
IndexType idx = indices[current]; \ 189 0 <= idx && idx < data_size, \ 192 " is out of bounds: ", \ 202 "Your input seems to be incorrect: the sum of lengths values should be " \ 203 "the size of the indices tensor, but it appears not."); \ 206 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int32_t,
float);
207 FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION(int64_t,
float);
// NOTE(review): the text after the #undef on the next line is not C++. It
// appears to be stray residue from how this listing was captured, not part
// of the source; confirm and remove it in the real file.
209 #undef FUSED_8BIT_ROWWISE_EMBEDDING_SPECIALIZATION A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...