#include <c10/util/Half.h>
#include <immintrin.h>

template <bool IS_WEIGHT_POSITIONAL>
static bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  const int prefdist_T0 = 16;
  const int fused_block_size = block_size + 2;
  int dataInd = 0;
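  // Layout note (inferred from the fused_block_size values and the uint8
  // variant below): each "fused" input row stores block_size values followed
  // by an 8-byte footer (per-row scale and bias in the uint8 case), which is
  // why fused_block_size adds 2 floats here, 4 halfs, or 8 bytes respectively.
  // The float and half kernels simply skip the footer.
  //
  // Per output row the kernel accumulates weighted embedding rows into
  // __m256 registers, prefetches the row that will be needed prefdist_T0
  // indices ahead, and optionally divides by the segment length at the end.
  // It returns false as soon as an index or a length walks out of bounds.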
  if (block_size == 128) {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
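        // The loads below consume the current row `ip`; the interleaved
        // _mm_prefetch calls pull in `ip_next_T0`, the row indexed
        // prefdist_T0 positions ahead, one 64-byte cache line (16 floats)
        // at a time, which is why only every other 8-float group issues a
        // prefetch.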
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
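    // Generic fallback for block sizes not specialized above: zero the
    // output row, accumulate 8 floats at a time with a scalar tail, then
    // optionally normalize by the segment length.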
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ip[j];
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
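// Illustrative sketch only (not part of the generated kernels, hence kept
// under #if 0): one way the non-positional int32/float wrapper above could
// be driven. The buffer shapes are assumptions inferred from the
// reconstructed signature: each fused input row holds block_size floats plus
// a 2-float footer, `indices` holds index_size row ids, and `lengths` splits
// them into output_size bags.
#if 0
static void Fused8BitRowwiseEmbeddingLookup_example() {
  constexpr int64_t block_size = 16, output_size = 1, index_size = 2,
                    data_size = 4;
  float input[data_size * (block_size + 2)];  // 4 fused rows of 16 + 2 floats
  for (float& v : input) {
    v = 1.0f;
  }
  int indices[index_size] = {0, 3};  // rows contributing to the single bag
  int lengths[output_size] = {2};    // one bag made of two rows
  float out[output_size * block_size] = {};
  const bool ok =
      Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma(
          block_size,
          output_size,
          index_size,
          data_size,
          input,
          indices,
          lengths,
          /*weights=*/nullptr,
          /*normalize_by_lengths=*/false,
          out);
  // ok is true only when exactly index_size indices were consumed.
  (void)ok;
}
#endif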
template <bool IS_WEIGHT_POSITIONAL>
static bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 2;
  int64_t dataInd = 0;
  if (block_size == 128) {
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
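      // Same unrolled structure as the int32_t variant above, with 64-bit
      // loop counters and indices.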
452 float* op = &out[rangeIndex * block_size];
453 __m256 vop0 = _mm256_setzero_ps();
454 __m256 vop8 = _mm256_setzero_ps();
455 __m256 vop16 = _mm256_setzero_ps();
456 __m256 vop24 = _mm256_setzero_ps();
457 __m256 vop32 = _mm256_setzero_ps();
458 __m256 vop40 = _mm256_setzero_ps();
459 __m256 vop48 = _mm256_setzero_ps();
460 __m256 vop56 = _mm256_setzero_ps();
461 __m256 vop64 = _mm256_setzero_ps();
462 __m256 vop72 = _mm256_setzero_ps();
463 __m256 vop80 = _mm256_setzero_ps();
464 __m256 vop88 = _mm256_setzero_ps();
465 __m256 vop96 = _mm256_setzero_ps();
466 __m256 vop104 = _mm256_setzero_ps();
467 __m256 vop112 = _mm256_setzero_ps();
468 __m256 vop120 = _mm256_setzero_ps();
469 if (dataInd + lengths[rangeIndex] > index_size) {
472 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
474 const int64_t idx = indices[dataInd];
475 if (idx < 0 || idx >= data_size) {
480 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
482 __m256 vwgt = _mm256_set1_ps(wgt);
483 const float* ip = &input[idx * fused_block_size];
484 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
485 ? (dataInd + prefdist_T0)
487 const int64_t idx_pref_T0 = indices[next_T0];
488 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
491 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
492 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
494 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
495 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
497 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
499 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
500 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
502 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
504 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
505 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
507 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
509 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
510 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
512 vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
514 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
515 vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
517 vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
519 reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
520 vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
522 vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
524 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
525 vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
527 vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
529 reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
530 vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
533 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
534 _mm256_storeu_ps(&op[0], vop0);
535 _mm256_storeu_ps(&op[8], vop8);
536 _mm256_storeu_ps(&op[16], vop16);
537 _mm256_storeu_ps(&op[24], vop24);
538 _mm256_storeu_ps(&op[32], vop32);
539 _mm256_storeu_ps(&op[40], vop40);
540 _mm256_storeu_ps(&op[48], vop48);
541 _mm256_storeu_ps(&op[56], vop56);
542 _mm256_storeu_ps(&op[64], vop64);
543 _mm256_storeu_ps(&op[72], vop72);
544 _mm256_storeu_ps(&op[80], vop80);
545 _mm256_storeu_ps(&op[88], vop88);
546 _mm256_storeu_ps(&op[96], vop96);
547 _mm256_storeu_ps(&op[104], vop104);
548 _mm256_storeu_ps(&op[112], vop112);
549 _mm256_storeu_ps(&op[120], vop120);
551 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
552 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
553 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
554 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
555 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
556 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
557 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
558 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
559 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
560 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
561 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
562 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
563 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
564 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
565 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
566 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
567 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
570 }
else if (block_size == 64) {
572 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
573 float* op = &out[rangeIndex * block_size];
574 __m256 vop0 = _mm256_setzero_ps();
575 __m256 vop8 = _mm256_setzero_ps();
576 __m256 vop16 = _mm256_setzero_ps();
577 __m256 vop24 = _mm256_setzero_ps();
578 __m256 vop32 = _mm256_setzero_ps();
579 __m256 vop40 = _mm256_setzero_ps();
580 __m256 vop48 = _mm256_setzero_ps();
581 __m256 vop56 = _mm256_setzero_ps();
582 if (dataInd + lengths[rangeIndex] > index_size) {
585 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
587 const int64_t idx = indices[dataInd];
588 if (idx < 0 || idx >= data_size) {
593 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
595 __m256 vwgt = _mm256_set1_ps(wgt);
596 const float* ip = &input[idx * fused_block_size];
597 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
598 ? (dataInd + prefdist_T0)
600 const int64_t idx_pref_T0 = indices[next_T0];
601 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
604 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
605 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
607 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
608 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
610 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
612 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
613 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
615 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
617 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
618 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
620 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
622 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
623 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
626 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
627 _mm256_storeu_ps(&op[0], vop0);
628 _mm256_storeu_ps(&op[8], vop8);
629 _mm256_storeu_ps(&op[16], vop16);
630 _mm256_storeu_ps(&op[24], vop24);
631 _mm256_storeu_ps(&op[32], vop32);
632 _mm256_storeu_ps(&op[40], vop40);
633 _mm256_storeu_ps(&op[48], vop48);
634 _mm256_storeu_ps(&op[56], vop56);
636 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
637 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
638 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
639 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
640 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
641 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
642 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
643 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
644 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
647 }
else if (block_size == 32) {
649 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
650 float* op = &out[rangeIndex * block_size];
651 __m256 vop0 = _mm256_setzero_ps();
652 __m256 vop8 = _mm256_setzero_ps();
653 __m256 vop16 = _mm256_setzero_ps();
654 __m256 vop24 = _mm256_setzero_ps();
655 if (dataInd + lengths[rangeIndex] > index_size) {
658 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
660 const int64_t idx = indices[dataInd];
661 if (idx < 0 || idx >= data_size) {
666 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
668 __m256 vwgt = _mm256_set1_ps(wgt);
669 const float* ip = &input[idx * fused_block_size];
670 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
671 ? (dataInd + prefdist_T0)
673 const int64_t idx_pref_T0 = indices[next_T0];
674 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
677 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
678 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
680 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
681 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
683 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
685 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
686 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
689 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
690 _mm256_storeu_ps(&op[0], vop0);
691 _mm256_storeu_ps(&op[8], vop8);
692 _mm256_storeu_ps(&op[16], vop16);
693 _mm256_storeu_ps(&op[24], vop24);
695 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
696 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
697 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
698 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
699 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
702 }
else if (block_size == 16) {
704 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
705 float* op = &out[rangeIndex * block_size];
706 __m256 vop0 = _mm256_setzero_ps();
707 __m256 vop8 = _mm256_setzero_ps();
708 if (dataInd + lengths[rangeIndex] > index_size) {
711 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
713 const int64_t idx = indices[dataInd];
714 if (idx < 0 || idx >= data_size) {
719 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
721 __m256 vwgt = _mm256_set1_ps(wgt);
722 const float* ip = &input[idx * fused_block_size];
723 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
724 ? (dataInd + prefdist_T0)
726 const int64_t idx_pref_T0 = indices[next_T0];
727 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
730 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
731 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
733 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
734 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
737 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
738 _mm256_storeu_ps(&op[0], vop0);
739 _mm256_storeu_ps(&op[8], vop8);
741 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
742 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
743 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
748 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
749 float* op = &out[rangeIndex * block_size];
751 for (; j + 8 <= block_size; j += 8) {
752 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
754 for (; j < block_size; j++) {
757 if (dataInd + lengths[rangeIndex] > index_size) {
760 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
762 const int64_t idx = indices[dataInd];
763 if (idx < 0 || idx >= data_size) {
768 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
770 __m256 vwgt = _mm256_set1_ps(wgt);
771 const float* ip = &input[idx * fused_block_size];
772 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
773 ? (dataInd + prefdist_T0)
775 const int64_t idx_pref_T0 = indices[next_T0];
776 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
779 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
781 for (; j + 8 <= block_size; j += 8) {
785 vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
787 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
789 for (; j < block_size; j++) {
790 op[j] += wgt * ip[j];
793 if (normalize_by_lengths && lengths[rangeIndex]) {
794 float len_inv = 1.0f / lengths[rangeIndex];
795 __m256 vlen_inv = _mm256_set1_ps(len_inv);
797 for (; j + 8 <= block_size; j += 8) {
799 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
801 for (; j < block_size; j++) {
802 op[j] = len_inv * op[j];
  return dataInd == index_size;
}
bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  const int prefdist_T0 = 16;
  const int fused_block_size = block_size + 4;
  int dataInd = 0;
  if (block_size == 128) {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
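      // Half-precision rows are loaded 8 values (16 bytes) at a time and
      // widened to float with _mm256_cvtph_ps before the FMA accumulate.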
874 float* op = &out[rangeIndex * block_size];
875 __m256 vop0 = _mm256_setzero_ps();
876 __m256 vop8 = _mm256_setzero_ps();
877 __m256 vop16 = _mm256_setzero_ps();
878 __m256 vop24 = _mm256_setzero_ps();
879 __m256 vop32 = _mm256_setzero_ps();
880 __m256 vop40 = _mm256_setzero_ps();
881 __m256 vop48 = _mm256_setzero_ps();
882 __m256 vop56 = _mm256_setzero_ps();
883 __m256 vop64 = _mm256_setzero_ps();
884 __m256 vop72 = _mm256_setzero_ps();
885 __m256 vop80 = _mm256_setzero_ps();
886 __m256 vop88 = _mm256_setzero_ps();
887 __m256 vop96 = _mm256_setzero_ps();
888 __m256 vop104 = _mm256_setzero_ps();
889 __m256 vop112 = _mm256_setzero_ps();
890 __m256 vop120 = _mm256_setzero_ps();
891 if (dataInd + lengths[rangeIndex] > index_size) {
894 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
896 const int idx = indices[dataInd];
897 if (idx < 0 || idx >= data_size) {
902 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
904 __m256 vwgt = _mm256_set1_ps(wgt);
905 const at::Half* ip = &input[idx * fused_block_size];
906 const int next_T0 = (dataInd < index_size - prefdist_T0)
907 ? (dataInd + prefdist_T0)
909 const int idx_pref_T0 = indices[next_T0];
910 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
913 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
914 vop0 = _mm256_fmadd_ps(
917 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
920 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
921 vop8 = _mm256_fmadd_ps(
924 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
927 vop16 = _mm256_fmadd_ps(
930 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
933 vop24 = _mm256_fmadd_ps(
936 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
939 vop32 = _mm256_fmadd_ps(
942 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
945 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
946 vop40 = _mm256_fmadd_ps(
949 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
952 vop48 = _mm256_fmadd_ps(
955 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
958 vop56 = _mm256_fmadd_ps(
961 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
964 vop64 = _mm256_fmadd_ps(
967 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
970 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
971 vop72 = _mm256_fmadd_ps(
974 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
977 vop80 = _mm256_fmadd_ps(
980 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
983 vop88 = _mm256_fmadd_ps(
986 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
989 vop96 = _mm256_fmadd_ps(
992 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
995 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
996 vop104 = _mm256_fmadd_ps(
999 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1002 vop112 = _mm256_fmadd_ps(
1005 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1008 vop120 = _mm256_fmadd_ps(
1011 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1015 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1016 _mm256_storeu_ps(&op[0], vop0);
1017 _mm256_storeu_ps(&op[8], vop8);
1018 _mm256_storeu_ps(&op[16], vop16);
1019 _mm256_storeu_ps(&op[24], vop24);
1020 _mm256_storeu_ps(&op[32], vop32);
1021 _mm256_storeu_ps(&op[40], vop40);
1022 _mm256_storeu_ps(&op[48], vop48);
1023 _mm256_storeu_ps(&op[56], vop56);
1024 _mm256_storeu_ps(&op[64], vop64);
1025 _mm256_storeu_ps(&op[72], vop72);
1026 _mm256_storeu_ps(&op[80], vop80);
1027 _mm256_storeu_ps(&op[88], vop88);
1028 _mm256_storeu_ps(&op[96], vop96);
1029 _mm256_storeu_ps(&op[104], vop104);
1030 _mm256_storeu_ps(&op[112], vop112);
1031 _mm256_storeu_ps(&op[120], vop120);
1033 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1034 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1035 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1036 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1037 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1038 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1039 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1040 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1041 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1042 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1043 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1044 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1045 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1046 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1047 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1048 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1049 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1052 }
else if (block_size == 64) {
1054 for (
int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1055 float* op = &out[rangeIndex * block_size];
1056 __m256 vop0 = _mm256_setzero_ps();
1057 __m256 vop8 = _mm256_setzero_ps();
1058 __m256 vop16 = _mm256_setzero_ps();
1059 __m256 vop24 = _mm256_setzero_ps();
1060 __m256 vop32 = _mm256_setzero_ps();
1061 __m256 vop40 = _mm256_setzero_ps();
1062 __m256 vop48 = _mm256_setzero_ps();
1063 __m256 vop56 = _mm256_setzero_ps();
1064 if (dataInd + lengths[rangeIndex] > index_size) {
1067 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1069 const int idx = indices[dataInd];
1070 if (idx < 0 || idx >= data_size) {
1075 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1077 __m256 vwgt = _mm256_set1_ps(wgt);
1078 const at::Half* ip = &input[idx * fused_block_size];
1079 const int next_T0 = (dataInd < index_size - prefdist_T0)
1080 ? (dataInd + prefdist_T0)
1082 const int idx_pref_T0 = indices[next_T0];
1083 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1086 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1087 vop0 = _mm256_fmadd_ps(
1090 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1093 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1094 vop8 = _mm256_fmadd_ps(
1097 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1100 vop16 = _mm256_fmadd_ps(
1103 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1106 vop24 = _mm256_fmadd_ps(
1109 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1112 vop32 = _mm256_fmadd_ps(
1115 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1118 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1119 vop40 = _mm256_fmadd_ps(
1122 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1125 vop48 = _mm256_fmadd_ps(
1128 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1131 vop56 = _mm256_fmadd_ps(
1134 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1138 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1139 _mm256_storeu_ps(&op[0], vop0);
1140 _mm256_storeu_ps(&op[8], vop8);
1141 _mm256_storeu_ps(&op[16], vop16);
1142 _mm256_storeu_ps(&op[24], vop24);
1143 _mm256_storeu_ps(&op[32], vop32);
1144 _mm256_storeu_ps(&op[40], vop40);
1145 _mm256_storeu_ps(&op[48], vop48);
1146 _mm256_storeu_ps(&op[56], vop56);
1148 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1149 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1150 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1151 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1152 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1153 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1154 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1155 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1156 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1159 }
else if (block_size == 32) {
1161 for (
int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1162 float* op = &out[rangeIndex * block_size];
1163 __m256 vop0 = _mm256_setzero_ps();
1164 __m256 vop8 = _mm256_setzero_ps();
1165 __m256 vop16 = _mm256_setzero_ps();
1166 __m256 vop24 = _mm256_setzero_ps();
1167 if (dataInd + lengths[rangeIndex] > index_size) {
1170 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1172 const int idx = indices[dataInd];
1173 if (idx < 0 || idx >= data_size) {
1178 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1180 __m256 vwgt = _mm256_set1_ps(wgt);
1181 const at::Half* ip = &input[idx * fused_block_size];
1182 const int next_T0 = (dataInd < index_size - prefdist_T0)
1183 ? (dataInd + prefdist_T0)
1185 const int idx_pref_T0 = indices[next_T0];
1186 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1189 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1190 vop0 = _mm256_fmadd_ps(
1193 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1196 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1197 vop8 = _mm256_fmadd_ps(
1200 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1203 vop16 = _mm256_fmadd_ps(
1206 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1209 vop24 = _mm256_fmadd_ps(
1212 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1216 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1217 _mm256_storeu_ps(&op[0], vop0);
1218 _mm256_storeu_ps(&op[8], vop8);
1219 _mm256_storeu_ps(&op[16], vop16);
1220 _mm256_storeu_ps(&op[24], vop24);
1222 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1223 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1224 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1225 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1226 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1229 }
else if (block_size == 16) {
1231 for (
int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1232 float* op = &out[rangeIndex * block_size];
1233 __m256 vop0 = _mm256_setzero_ps();
1234 __m256 vop8 = _mm256_setzero_ps();
1235 if (dataInd + lengths[rangeIndex] > index_size) {
1238 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1240 const int idx = indices[dataInd];
1241 if (idx < 0 || idx >= data_size) {
1246 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1248 __m256 vwgt = _mm256_set1_ps(wgt);
1249 const at::Half* ip = &input[idx * fused_block_size];
1250 const int next_T0 = (dataInd < index_size - prefdist_T0)
1251 ? (dataInd + prefdist_T0)
1253 const int idx_pref_T0 = indices[next_T0];
1254 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1257 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1258 vop0 = _mm256_fmadd_ps(
1261 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1264 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1265 vop8 = _mm256_fmadd_ps(
1268 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1272 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1273 _mm256_storeu_ps(&op[0], vop0);
1274 _mm256_storeu_ps(&op[8], vop8);
1276 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1277 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1278 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::Half* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtph_ps(_mm_loadu_si128(
                      reinterpret_cast<const __m128i*>(&ip[j]))),
                  _mm256_loadu_ps(&op[j])));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
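        // Scalar tail: convert one at::Half at a time by staging it in an
        // 8-element buffer, widening with _mm256_cvtph_ps, and taking lane 0.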
        for (; j < block_size; j++) {
          alignas(64) at::Half vtmp1[8] = {ip[j]};
          __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
          op[j] += wgt * ((float*)(&vtmp2))[0];
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 4;
  int64_t dataInd = 0;
  if (block_size == 128) {
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1415 float* op = &out[rangeIndex * block_size];
1416 __m256 vop0 = _mm256_setzero_ps();
1417 __m256 vop8 = _mm256_setzero_ps();
1418 __m256 vop16 = _mm256_setzero_ps();
1419 __m256 vop24 = _mm256_setzero_ps();
1420 __m256 vop32 = _mm256_setzero_ps();
1421 __m256 vop40 = _mm256_setzero_ps();
1422 __m256 vop48 = _mm256_setzero_ps();
1423 __m256 vop56 = _mm256_setzero_ps();
1424 __m256 vop64 = _mm256_setzero_ps();
1425 __m256 vop72 = _mm256_setzero_ps();
1426 __m256 vop80 = _mm256_setzero_ps();
1427 __m256 vop88 = _mm256_setzero_ps();
1428 __m256 vop96 = _mm256_setzero_ps();
1429 __m256 vop104 = _mm256_setzero_ps();
1430 __m256 vop112 = _mm256_setzero_ps();
1431 __m256 vop120 = _mm256_setzero_ps();
1432 if (dataInd + lengths[rangeIndex] > index_size) {
1435 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1437 const int64_t idx = indices[dataInd];
1438 if (idx < 0 || idx >= data_size) {
1443 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1445 __m256 vwgt = _mm256_set1_ps(wgt);
1446 const at::Half* ip = &input[idx * fused_block_size];
1447 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1448 ? (dataInd + prefdist_T0)
1450 const int64_t idx_pref_T0 = indices[next_T0];
1451 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1454 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1455 vop0 = _mm256_fmadd_ps(
1458 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1461 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1462 vop8 = _mm256_fmadd_ps(
1465 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1468 vop16 = _mm256_fmadd_ps(
1471 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1474 vop24 = _mm256_fmadd_ps(
1477 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1480 vop32 = _mm256_fmadd_ps(
1483 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1486 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1487 vop40 = _mm256_fmadd_ps(
1490 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1493 vop48 = _mm256_fmadd_ps(
1496 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1499 vop56 = _mm256_fmadd_ps(
1502 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1505 vop64 = _mm256_fmadd_ps(
1508 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1511 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1512 vop72 = _mm256_fmadd_ps(
1515 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1518 vop80 = _mm256_fmadd_ps(
1521 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1524 vop88 = _mm256_fmadd_ps(
1527 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1530 vop96 = _mm256_fmadd_ps(
1533 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1536 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1537 vop104 = _mm256_fmadd_ps(
1540 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1543 vop112 = _mm256_fmadd_ps(
1546 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1549 vop120 = _mm256_fmadd_ps(
1552 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1556 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1557 _mm256_storeu_ps(&op[0], vop0);
1558 _mm256_storeu_ps(&op[8], vop8);
1559 _mm256_storeu_ps(&op[16], vop16);
1560 _mm256_storeu_ps(&op[24], vop24);
1561 _mm256_storeu_ps(&op[32], vop32);
1562 _mm256_storeu_ps(&op[40], vop40);
1563 _mm256_storeu_ps(&op[48], vop48);
1564 _mm256_storeu_ps(&op[56], vop56);
1565 _mm256_storeu_ps(&op[64], vop64);
1566 _mm256_storeu_ps(&op[72], vop72);
1567 _mm256_storeu_ps(&op[80], vop80);
1568 _mm256_storeu_ps(&op[88], vop88);
1569 _mm256_storeu_ps(&op[96], vop96);
1570 _mm256_storeu_ps(&op[104], vop104);
1571 _mm256_storeu_ps(&op[112], vop112);
1572 _mm256_storeu_ps(&op[120], vop120);
1574 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1575 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1576 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1577 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1578 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1579 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1580 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1581 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1582 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1583 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1584 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1585 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1586 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1587 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1588 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1589 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1590 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1593 }
else if (block_size == 64) {
1595 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1596 float* op = &out[rangeIndex * block_size];
1597 __m256 vop0 = _mm256_setzero_ps();
1598 __m256 vop8 = _mm256_setzero_ps();
1599 __m256 vop16 = _mm256_setzero_ps();
1600 __m256 vop24 = _mm256_setzero_ps();
1601 __m256 vop32 = _mm256_setzero_ps();
1602 __m256 vop40 = _mm256_setzero_ps();
1603 __m256 vop48 = _mm256_setzero_ps();
1604 __m256 vop56 = _mm256_setzero_ps();
1605 if (dataInd + lengths[rangeIndex] > index_size) {
1608 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1610 const int64_t idx = indices[dataInd];
1611 if (idx < 0 || idx >= data_size) {
1616 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1618 __m256 vwgt = _mm256_set1_ps(wgt);
1619 const at::Half* ip = &input[idx * fused_block_size];
1620 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1621 ? (dataInd + prefdist_T0)
1623 const int64_t idx_pref_T0 = indices[next_T0];
1624 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1627 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1628 vop0 = _mm256_fmadd_ps(
1631 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1634 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1635 vop8 = _mm256_fmadd_ps(
1638 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1641 vop16 = _mm256_fmadd_ps(
1644 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1647 vop24 = _mm256_fmadd_ps(
1650 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1653 vop32 = _mm256_fmadd_ps(
1656 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1659 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1660 vop40 = _mm256_fmadd_ps(
1663 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1666 vop48 = _mm256_fmadd_ps(
1669 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1672 vop56 = _mm256_fmadd_ps(
1675 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1679 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1680 _mm256_storeu_ps(&op[0], vop0);
1681 _mm256_storeu_ps(&op[8], vop8);
1682 _mm256_storeu_ps(&op[16], vop16);
1683 _mm256_storeu_ps(&op[24], vop24);
1684 _mm256_storeu_ps(&op[32], vop32);
1685 _mm256_storeu_ps(&op[40], vop40);
1686 _mm256_storeu_ps(&op[48], vop48);
1687 _mm256_storeu_ps(&op[56], vop56);
1689 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1690 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1691 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1692 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1693 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1694 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1695 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1696 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1697 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1700 }
else if (block_size == 32) {
1702 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1703 float* op = &out[rangeIndex * block_size];
1704 __m256 vop0 = _mm256_setzero_ps();
1705 __m256 vop8 = _mm256_setzero_ps();
1706 __m256 vop16 = _mm256_setzero_ps();
1707 __m256 vop24 = _mm256_setzero_ps();
1708 if (dataInd + lengths[rangeIndex] > index_size) {
1711 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1713 const int64_t idx = indices[dataInd];
1714 if (idx < 0 || idx >= data_size) {
1719 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1721 __m256 vwgt = _mm256_set1_ps(wgt);
1722 const at::Half* ip = &input[idx * fused_block_size];
1723 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1724 ? (dataInd + prefdist_T0)
1726 const int64_t idx_pref_T0 = indices[next_T0];
1727 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1730 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1731 vop0 = _mm256_fmadd_ps(
1734 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1737 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1738 vop8 = _mm256_fmadd_ps(
1741 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1744 vop16 = _mm256_fmadd_ps(
1747 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1750 vop24 = _mm256_fmadd_ps(
1753 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1757 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1758 _mm256_storeu_ps(&op[0], vop0);
1759 _mm256_storeu_ps(&op[8], vop8);
1760 _mm256_storeu_ps(&op[16], vop16);
1761 _mm256_storeu_ps(&op[24], vop24);
1763 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1764 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1765 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1766 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1767 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1770 }
else if (block_size == 16) {
1772 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1773 float* op = &out[rangeIndex * block_size];
1774 __m256 vop0 = _mm256_setzero_ps();
1775 __m256 vop8 = _mm256_setzero_ps();
1776 if (dataInd + lengths[rangeIndex] > index_size) {
1779 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1781 const int64_t idx = indices[dataInd];
1782 if (idx < 0 || idx >= data_size) {
1787 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1789 __m256 vwgt = _mm256_set1_ps(wgt);
1790 const at::Half* ip = &input[idx * fused_block_size];
1791 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1792 ? (dataInd + prefdist_T0)
1794 const int64_t idx_pref_T0 = indices[next_T0];
1795 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1798 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1799 vop0 = _mm256_fmadd_ps(
1802 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1805 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1806 vop8 = _mm256_fmadd_ps(
1809 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1813 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1814 _mm256_storeu_ps(&op[0], vop0);
1815 _mm256_storeu_ps(&op[8], vop8);
1817 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1818 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1819 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1824 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1825 float* op = &out[rangeIndex * block_size];
1827 for (; j + 8 <= block_size; j += 8) {
1828 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1830 for (; j < block_size; j++) {
1833 if (dataInd + lengths[rangeIndex] > index_size) {
1836 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1838 const int64_t idx = indices[dataInd];
1839 if (idx < 0 || idx >= data_size) {
1844 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1846 __m256 vwgt = _mm256_set1_ps(wgt);
1847 const at::Half* ip = &input[idx * fused_block_size];
1848 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1849 ? (dataInd + prefdist_T0)
1851 const int64_t idx_pref_T0 = indices[next_T0];
1852 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1855 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1857 for (; j + 8 <= block_size; j += 8) {
1862 _mm256_cvtph_ps(_mm_loadu_si128(
1863 reinterpret_cast<const __m128i*>(&ip[j]))),
1864 _mm256_loadu_ps(&op[j])));
1866 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1869 for (; j < block_size; j++) {
1871 __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1872 op[j] += wgt * ((
float*)(&vtmp2))[0];
1875 if (normalize_by_lengths && lengths[rangeIndex]) {
1876 float len_inv = 1.0f / lengths[rangeIndex];
1877 __m256 vlen_inv = _mm256_set1_ps(len_inv);
1879 for (; j + 8 <= block_size; j += 8) {
1881 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1883 for (; j < block_size; j++) {
1884 op[j] = len_inv * op[j];
  return dataInd == index_size;
}
bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      normalize_by_lengths,
      out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  const int prefdist_T0 = 16;
  const int fused_block_size = block_size + 8;
  int dataInd = 0;
  if (block_size == 128) {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
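      // uint8 rows are dequantized on the fly: the per-row float scale is
      // folded into the weight (vwgt) and the per-row bias, scaled by the
      // weight, is added in through vbio, so each lane accumulates
      // wgt * (scale * q + bias) into vopN.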
1956 float* op = &out[rangeIndex * block_size];
1957 __m256 vop0 = _mm256_setzero_ps();
1958 __m256 vop8 = _mm256_setzero_ps();
1959 __m256 vop16 = _mm256_setzero_ps();
1960 __m256 vop24 = _mm256_setzero_ps();
1961 __m256 vop32 = _mm256_setzero_ps();
1962 __m256 vop40 = _mm256_setzero_ps();
1963 __m256 vop48 = _mm256_setzero_ps();
1964 __m256 vop56 = _mm256_setzero_ps();
1965 __m256 vop64 = _mm256_setzero_ps();
1966 __m256 vop72 = _mm256_setzero_ps();
1967 __m256 vop80 = _mm256_setzero_ps();
1968 __m256 vop88 = _mm256_setzero_ps();
1969 __m256 vop96 = _mm256_setzero_ps();
1970 __m256 vop104 = _mm256_setzero_ps();
1971 __m256 vop112 = _mm256_setzero_ps();
1972 __m256 vop120 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2002 vop0 = _mm256_fmadd_ps(
2004 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2005 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2006 _mm256_add_ps(vop0, vbio));
2008 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2009 vop8 = _mm256_fmadd_ps(
2011 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2012 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2013 _mm256_add_ps(vop8, vbio));
2015 vop16 = _mm256_fmadd_ps(
2017 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2018 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2019 _mm256_add_ps(vop16, vbio));
2021 vop24 = _mm256_fmadd_ps(
2023 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2024 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2025 _mm256_add_ps(vop24, vbio));
2027 vop32 = _mm256_fmadd_ps(
2029 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2030 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2031 _mm256_add_ps(vop32, vbio));
2033 vop40 = _mm256_fmadd_ps(
2035 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2036 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2037 _mm256_add_ps(vop40, vbio));
2039 vop48 = _mm256_fmadd_ps(
2041 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2042 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2043 _mm256_add_ps(vop48, vbio));
2045 vop56 = _mm256_fmadd_ps(
2047 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2048 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2049 _mm256_add_ps(vop56, vbio));
2051 vop64 = _mm256_fmadd_ps(
2053 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2054 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2055 _mm256_add_ps(vop64, vbio));
2057 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2058 vop72 = _mm256_fmadd_ps(
2060 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2061 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2062 _mm256_add_ps(vop72, vbio));
2064 vop80 = _mm256_fmadd_ps(
2066 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2067 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2068 _mm256_add_ps(vop80, vbio));
2070 vop88 = _mm256_fmadd_ps(
2072 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2073 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2074 _mm256_add_ps(vop88, vbio));
2076 vop96 = _mm256_fmadd_ps(
2078 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2079 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2080 _mm256_add_ps(vop96, vbio));
2082 vop104 = _mm256_fmadd_ps(
2084 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2085 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2086 _mm256_add_ps(vop104, vbio));
2088 vop112 = _mm256_fmadd_ps(
2090 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2091 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2092 _mm256_add_ps(vop112, vbio));
2094 vop120 = _mm256_fmadd_ps(
2096 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2097 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2098 _mm256_add_ps(vop120, vbio));
2101 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2102 _mm256_storeu_ps(&op[0], vop0);
2103 _mm256_storeu_ps(&op[8], vop8);
2104 _mm256_storeu_ps(&op[16], vop16);
2105 _mm256_storeu_ps(&op[24], vop24);
2106 _mm256_storeu_ps(&op[32], vop32);
2107 _mm256_storeu_ps(&op[40], vop40);
2108 _mm256_storeu_ps(&op[48], vop48);
2109 _mm256_storeu_ps(&op[56], vop56);
2110 _mm256_storeu_ps(&op[64], vop64);
2111 _mm256_storeu_ps(&op[72], vop72);
2112 _mm256_storeu_ps(&op[80], vop80);
2113 _mm256_storeu_ps(&op[88], vop88);
2114 _mm256_storeu_ps(&op[96], vop96);
2115 _mm256_storeu_ps(&op[104], vop104);
2116 _mm256_storeu_ps(&op[112], vop112);
2117 _mm256_storeu_ps(&op[120], vop120);
2119 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2120 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2121 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2122 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2123 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2124 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2125 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2126 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2127 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2128 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2129 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2130 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2131 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2132 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2133 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2134 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2135 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
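// Each kernel above returns true only when every segment was fully consumed
// (dataInd == index_size) and every index it touched, including the prefetch
// index, was inside [0, data_size); any violation makes it bail out early
// with false before finishing the output.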
bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, normalize_by_lengths, out);
}

bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, normalize_by_lengths, out);
}
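// The _false / _true suffix on these wrappers fixes the IS_WEIGHT_POSITIONAL
// template argument: false indexes the optional weights array by the global
// position dataInd, true indexes it by the position within the current
// segment (dataInd - start).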
template <bool IS_WEIGHT_POSITIONAL>
static bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 8;
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
            _mm256_add_ps(vop64, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
            _mm256_add_ps(vop72, vbio));
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
            _mm256_add_ps(vop80, vbio));
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
            _mm256_add_ps(vop88, vbio));
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
            _mm256_add_ps(vop96, vbio));
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
            _mm256_add_ps(vop104, vbio));
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
            _mm256_add_ps(vop112, vbio));
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
            _mm256_add_ps(vop120, vbio));
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        const float* scale_bias = reinterpret_cast<const float*>(
            &input[idx * fused_block_size + block_size]);
        bio = wgt * scale_bias[1];
        wgt = wgt * scale_bias[0];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
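// Note: apart from taking int64_t indices and using int64_t loop counters,
// this kernel mirrors the int32_t variant above.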
bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, normalize_by_lengths, out);
}

bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  return Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, normalize_by_lengths, out);
}
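// Editorial sketch (not part of the original file): one way a caller could
// drive the int64_t/uint8_t kernel above. It pools two segments of lengths
// 2 and 1 from a 4-row fused table with block_size = 8, so each row is
// 8 uint8 codes followed by a float scale and a float bias. The helper name
// and the toy values are illustrative only; the parameter order matches the
// wrapper definitions in this file.
inline void Fused8BitRowwiseEmbeddingLookup_example() {
  constexpr int64_t block_size = 8;
  constexpr int64_t data_size = 4; // rows in the quantized table
  constexpr int64_t output_size = 2; // pooled output rows (segments)
  constexpr int64_t index_size = 3; // total number of gathered indices
  constexpr int64_t fused_block_size = block_size + 8;

  uint8_t input[data_size * fused_block_size];
  for (int64_t r = 0; r < data_size; ++r) {
    uint8_t* row = input + r * fused_block_size;
    for (int64_t j = 0; j < block_size; ++j) {
      row[j] = static_cast<uint8_t>(r * 10 + j); // arbitrary codes
    }
    float* scale_bias = reinterpret_cast<float*>(row + block_size);
    scale_bias[0] = 0.5f; // scale
    scale_bias[1] = 1.0f; // bias
  }

  const int64_t indices[index_size] = {0, 2, 3};
  const int lengths[output_size] = {2, 1};
  float out[output_size * block_size];

  // No per-index weights; mean-pool each segment via normalize_by_lengths.
  Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, /*weights=*/nullptr, /*normalize_by_lengths=*/true, out);
}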