1 #include "caffe2/perfkernels/embedding_lookup.h" 3 #include "caffe2/core/types.h" 4 #include "caffe2/perfkernels/common.h" 5 #include "caffe2/perfkernels/typed_axpy.h" 6 #include "caffe2/utils/eigen_utils.h" 7 #include "caffe2/utils/math.h" 19 bool IS_WEIGHT_POSITIONAL =
false>
// Generic fallback for the embedding lookup: for each of `output_size` output
// rows it accumulates weighted rows of `input`, selected through `indices`,
// into a block of `block_size` elements at `out`.
//
// Returns true only when exactly `index_size` indices were consumed
// (see the final `current == index_size`).
//
// NOTE(review): this view of the function is incomplete. Several parameters
// used below (`input`, `lengths`, `weights`, `out`), the declaration of
// `current`, the bodies of the failing checks, the guards around the
// weight / scale_bias / bias steps, and the statements that advance `current`
// and `out` are not visible here. Confirm against the full file.
20 static bool EmbeddingLookupGenericSlow(
21 const int64_t block_size,
22 const int64_t output_size,
23 const int64_t index_size,
24 const int64_t data_size,
26 const IndexType* indices,
29 const float* scale_bias,
30 bool normalize_by_lengths,
// One iteration per output row.
// NOTE(review): the counter is `int` while `output_size` is `int64_t`.
33 for (
int m = 0; m < output_size; ++m) {
// Start the row from zero: clears `block_size` elements of OutType at `out`.
34 memset(out, 0,
sizeof(OutType) * block_size);
// Eigen array view over the same `block_size` elements of `out`; used below
// to add the scalar bias to every element.
35 EigenVectorArrayMap<OutType> out_vector(out, block_size);
// Segment m would read past the end of `indices`.
// (The branch body is not visible; presumably it reports failure.)
36 if (current + lengths[m] > index_size) {
// One iteration per index belonging to segment m.
39 for (
int i = 0; i < lengths[m]; ++i) {
40 int64_t idx = indices[current];
// Reject indices outside [0, data_size). (Branch body not visible.)
41 if (idx < 0 || idx >= data_size) {
// If another index follows, issue a read prefetch for the input row it
// selects before processing the current one.
45 if (current + 1 < index_size) {
46 __builtin_prefetch(input + block_size * indices[current + 1], 0, 1);
// Defaults: unit weight, zero bias.
50 float w = 1.f, b = 0.f;
// Positional weights are indexed by the position inside the segment (i);
// otherwise by the global position in `indices` (current).
// (The guard on this line is not visible; presumably `weights` is optional.)
52 w = weights[IS_WEIGHT_POSITIONAL ? i : current];
// `scale_bias` is read as interleaved pairs per input row:
// element 2*idx is the scale and element 2*idx+1 is the bias.
// Both are multiplied by the weight. (Guard not visible.)
55 b = w * scale_bias[2 * indices[current] + 1];
56 w = w * scale_bias[2 * indices[current]];
// Axpy-style accumulation of the selected input row into `out` with factor w
// (presumably out += w * row; TypedAxpy is defined elsewhere).
59 TypedAxpy<InType, OutType>(
60 block_size, w, input + block_size * indices[current], out);
// Adds the scalar bias b to every element of the output row.
// (Guard not visible.)
63 out_vector = out_vector + b;
// Optional normalization; skipped for empty segments, which also avoids a
// division by zero.
68 if (normalize_by_lengths && lengths[m]) {
// `out` is passed as both source and destination, with factor 1 / lengths[m].
70 math::Scale<float, OutType, CPUContext>(
71 block_size, 1.f / lengths[m], out, out,
nullptr);
// Success only if the segments consumed exactly `index_size` indices.
75 return current == index_size;
// EMBEDDING_SPECIALIZATION(IndexType, InTypeName, InType, OutType,
//                          IS_WEIGHT_POSITIONAL)
//
// For one combination of types, the visible macro text generates:
//   * EmbeddingLookup_<...>__base, which forwards to
//     EmbeddingLookupGenericSlow<..., IS_WEIGHT_POSITIONAL>;
//   * a reference to EmbeddingLookup_<...>__avx2_fma (presumably a
//     declaration of an implementation defined elsewhere; verify);
//   * a dispatcher EmbeddingLookup_<...> that enforces `scale_bias` is
//     non-null when InType is uint8_t. The "must be nullptr" check is
//     presumably the else branch; the `else` line is not visible. The
//     dispatcher then selects an implementation (the dispatch helpers are
//     not visible here);
//   * a specialization of EmbeddingLookup<IndexType, InType, OutType,
//     IS_WEIGHT_POSITIONAL> that calls the dispatcher and then walks
//     `lengths` / `indices` with CAFFE_ENFORCE checks: current < index_size,
//     0 <= idx < data_size, and a final message about the sum of lengths.
//     Presumably this walk runs only when the dispatcher reports failure,
//     so that a precise error is raised; confirm in the full file.
//
// NOTE(review): the macro text below is collapsed and has missing lines in
// this view; do not treat it as the complete definition.
//
// The first group of instantiations (IS_WEIGHT_POSITIONAL = false) starts at
// the end of the macro text: {int32_t, int64_t} index types x
// {float, at::Half, uint8_t} input types, all with float output.
79 #define EMBEDDING_SPECIALIZATION( \ 80 IndexType, InTypeName, InType, OutType, IS_WEIGHT_POSITIONAL) \ 82 EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base( \ 83 const int64_t block_size, \ 84 const int64_t output_size, \ 85 const int64_t index_size, \ 86 const int64_t data_size, \ 87 const InType* input, \ 88 const IndexType* indices, \ 90 const float* weights, \ 91 const float* scale_bias, \ 92 bool normalize_by_lengths, \ 94 return EmbeddingLookupGenericSlow< \ 98 IS_WEIGHT_POSITIONAL>( \ 108 normalize_by_lengths, \ 112 EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__base) \ 113 EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL##__avx2_fma; \ 115 EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \ 116 const int64_t block_size, \ 117 const int64_t output_size, \ 118 const int64_t index_size, \ 119 const int64_t data_size, \ 120 const InType* input, \ 121 const IndexType* indices, \ 122 const int* lengths, \ 123 const float* weights, \ 124 const float* scale_bias, \ 125 bool normalize_by_lengths, \ 127 if (std::is_same<InType, uint8_t>::value) { \ 128 CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr"); \ 130 CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr"); \ 133 EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ 143 normalize_by_lengths, \ 146 EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL, \ 156 normalize_by_lengths, \ 160 void EmbeddingLookup<IndexType, InType, OutType, IS_WEIGHT_POSITIONAL>( \ 161 const int64_t block_size, \ 162 const int64_t output_size, \ 163 const int64_t index_size, \ 164 const int64_t data_size, \ 165 const InType* input, \ 166 const IndexType* indices, \ 167 const int* lengths, \ 168 const float* weights, \ 169 const float* scale_bias, \ 170 bool normalize_by_lengths, \ 173 
EmbeddingLookup_##IndexType##_##InTypeName##_##OutType##_##IS_WEIGHT_POSITIONAL( \ 183 normalize_by_lengths, \ 188 int64_t current = 0; \ 189 for (int m = 0; m < output_size; ++m) { \ 190 for (int i = 0; i < lengths[m]; ++i) { \ 191 CAFFE_ENFORCE_LT(current, index_size); \ 192 IndexType idx = indices[current]; \ 194 0 <= idx && idx < data_size, \ 197 " is out of bounds: ", \ 207 "Your input seems to be incorrect: the sum of lengths values should be " \ 208 "the size of the indices tensor, but it appears not."); \ 211 EMBEDDING_SPECIALIZATION(int32_t,
float,
float,
float,
false);
212 EMBEDDING_SPECIALIZATION(int64_t,
float,
float,
float,
false);
213 EMBEDDING_SPECIALIZATION(int32_t, half,
at::Half,
float,
false);
214 EMBEDDING_SPECIALIZATION(int64_t, half,
at::Half,
float,
false);
215 EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t,
float,
false);
216 EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t,
float,
false);
// Same six type combinations, now with IS_WEIGHT_POSITIONAL = true.
218 EMBEDDING_SPECIALIZATION(int32_t,
float,
float,
float,
true);
219 EMBEDDING_SPECIALIZATION(int64_t,
float,
float,
float,
true);
220 EMBEDDING_SPECIALIZATION(int32_t, half,
at::Half,
float,
true);
221 EMBEDDING_SPECIALIZATION(int64_t, half,
at::Half,
float,
true);
222 EMBEDDING_SPECIALIZATION(int32_t, uint8_t, uint8_t,
float,
true);
223 EMBEDDING_SPECIALIZATION(int64_t, uint8_t, uint8_t,
float,
true);
// The helper macro is only needed for the instantiations above.
225 #undef EMBEDDING_SPECIALIZATION
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...