#include <c10/util/Half.h>
#include <immintrin.h>

template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookup_int32_t_float_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int prefdist_T0 = 16;
  const int fused_block_size = block_size + 0;
  int dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
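    // Each output row `op` accumulates wgt * input[idx * fused_block_size + k]
    // over its segment of `indices`; the row belonging to the index
    // prefdist_T0 positions ahead is prefetched to hide memory latency.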
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ip[j];
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
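// Non-templated entry points: each pair below instantiates the kernel above
// with IS_WEIGHT_POSITIONAL fixed to false or true at compile time.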
bool EmbeddingLookup_int32_t_float_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int32_t_float_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
bool EmbeddingLookup_int32_t_float_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int32_t_float_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
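// Same kernel again, specialized for 64-bit indices.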
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookup_int64_t_float_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 0;
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
458 float* op = &out[rangeIndex * block_size];
459 __m256 vop0 = _mm256_setzero_ps();
460 __m256 vop8 = _mm256_setzero_ps();
461 __m256 vop16 = _mm256_setzero_ps();
462 __m256 vop24 = _mm256_setzero_ps();
463 __m256 vop32 = _mm256_setzero_ps();
464 __m256 vop40 = _mm256_setzero_ps();
465 __m256 vop48 = _mm256_setzero_ps();
466 __m256 vop56 = _mm256_setzero_ps();
467 __m256 vop64 = _mm256_setzero_ps();
468 __m256 vop72 = _mm256_setzero_ps();
469 __m256 vop80 = _mm256_setzero_ps();
470 __m256 vop88 = _mm256_setzero_ps();
471 __m256 vop96 = _mm256_setzero_ps();
472 __m256 vop104 = _mm256_setzero_ps();
473 __m256 vop112 = _mm256_setzero_ps();
474 __m256 vop120 = _mm256_setzero_ps();
475 if (dataInd + lengths[rangeIndex] > index_size) {
478 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
480 const int64_t idx = indices[dataInd];
481 if (idx < 0 || idx >= data_size) {
486 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
488 __m256 vwgt = _mm256_set1_ps(wgt);
489 const float* ip = &input[idx * fused_block_size];
490 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
491 ? (dataInd + prefdist_T0)
493 const int64_t idx_pref_T0 = indices[next_T0];
494 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
497 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
498 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
500 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
501 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
503 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
505 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
506 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
508 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
510 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
511 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
513 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
515 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
516 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
518 vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
520 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
521 vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
523 vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
525 reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
526 vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
528 vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
530 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
531 vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
533 vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
535 reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
536 vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
539 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
540 _mm256_storeu_ps(&op[0], vop0);
541 _mm256_storeu_ps(&op[8], vop8);
542 _mm256_storeu_ps(&op[16], vop16);
543 _mm256_storeu_ps(&op[24], vop24);
544 _mm256_storeu_ps(&op[32], vop32);
545 _mm256_storeu_ps(&op[40], vop40);
546 _mm256_storeu_ps(&op[48], vop48);
547 _mm256_storeu_ps(&op[56], vop56);
548 _mm256_storeu_ps(&op[64], vop64);
549 _mm256_storeu_ps(&op[72], vop72);
550 _mm256_storeu_ps(&op[80], vop80);
551 _mm256_storeu_ps(&op[88], vop88);
552 _mm256_storeu_ps(&op[96], vop96);
553 _mm256_storeu_ps(&op[104], vop104);
554 _mm256_storeu_ps(&op[112], vop112);
555 _mm256_storeu_ps(&op[120], vop120);
557 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
558 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
559 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
560 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
561 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
562 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
563 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
564 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
565 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
566 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
567 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
568 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
569 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
570 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
571 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
572 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
573 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
576 }
else if (block_size == 64) {
578 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
579 float* op = &out[rangeIndex * block_size];
580 __m256 vop0 = _mm256_setzero_ps();
581 __m256 vop8 = _mm256_setzero_ps();
582 __m256 vop16 = _mm256_setzero_ps();
583 __m256 vop24 = _mm256_setzero_ps();
584 __m256 vop32 = _mm256_setzero_ps();
585 __m256 vop40 = _mm256_setzero_ps();
586 __m256 vop48 = _mm256_setzero_ps();
587 __m256 vop56 = _mm256_setzero_ps();
588 if (dataInd + lengths[rangeIndex] > index_size) {
591 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
593 const int64_t idx = indices[dataInd];
594 if (idx < 0 || idx >= data_size) {
599 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
601 __m256 vwgt = _mm256_set1_ps(wgt);
602 const float* ip = &input[idx * fused_block_size];
603 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
604 ? (dataInd + prefdist_T0)
606 const int64_t idx_pref_T0 = indices[next_T0];
607 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
610 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
611 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
613 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
614 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
616 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
618 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
619 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
621 vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
623 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
624 vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
626 vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
628 reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
629 vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
632 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
633 _mm256_storeu_ps(&op[0], vop0);
634 _mm256_storeu_ps(&op[8], vop8);
635 _mm256_storeu_ps(&op[16], vop16);
636 _mm256_storeu_ps(&op[24], vop24);
637 _mm256_storeu_ps(&op[32], vop32);
638 _mm256_storeu_ps(&op[40], vop40);
639 _mm256_storeu_ps(&op[48], vop48);
640 _mm256_storeu_ps(&op[56], vop56);
642 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
643 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
644 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
645 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
646 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
647 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
648 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
649 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
650 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
653 }
else if (block_size == 32) {
655 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
656 float* op = &out[rangeIndex * block_size];
657 __m256 vop0 = _mm256_setzero_ps();
658 __m256 vop8 = _mm256_setzero_ps();
659 __m256 vop16 = _mm256_setzero_ps();
660 __m256 vop24 = _mm256_setzero_ps();
661 if (dataInd + lengths[rangeIndex] > index_size) {
664 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
666 const int64_t idx = indices[dataInd];
667 if (idx < 0 || idx >= data_size) {
672 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
674 __m256 vwgt = _mm256_set1_ps(wgt);
675 const float* ip = &input[idx * fused_block_size];
676 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
677 ? (dataInd + prefdist_T0)
679 const int64_t idx_pref_T0 = indices[next_T0];
680 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
683 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
684 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
686 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
687 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
689 vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
691 reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
692 vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
695 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
696 _mm256_storeu_ps(&op[0], vop0);
697 _mm256_storeu_ps(&op[8], vop8);
698 _mm256_storeu_ps(&op[16], vop16);
699 _mm256_storeu_ps(&op[24], vop24);
701 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
702 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
703 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
704 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
705 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
708 }
else if (block_size == 16) {
710 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
711 float* op = &out[rangeIndex * block_size];
712 __m256 vop0 = _mm256_setzero_ps();
713 __m256 vop8 = _mm256_setzero_ps();
714 if (dataInd + lengths[rangeIndex] > index_size) {
717 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
719 const int64_t idx = indices[dataInd];
720 if (idx < 0 || idx >= data_size) {
725 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
727 __m256 vwgt = _mm256_set1_ps(wgt);
728 const float* ip = &input[idx * fused_block_size];
729 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
730 ? (dataInd + prefdist_T0)
732 const int64_t idx_pref_T0 = indices[next_T0];
733 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
736 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
737 vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
739 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
740 vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
743 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
744 _mm256_storeu_ps(&op[0], vop0);
745 _mm256_storeu_ps(&op[8], vop8);
747 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
748 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
749 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
754 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
755 float* op = &out[rangeIndex * block_size];
757 for (; j + 8 <= block_size; j += 8) {
758 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
760 for (; j < block_size; j++) {
763 if (dataInd + lengths[rangeIndex] > index_size) {
766 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
768 const int64_t idx = indices[dataInd];
769 if (idx < 0 || idx >= data_size) {
774 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
776 __m256 vwgt = _mm256_set1_ps(wgt);
777 const float* ip = &input[idx * fused_block_size];
778 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
779 ? (dataInd + prefdist_T0)
781 const int64_t idx_pref_T0 = indices[next_T0];
782 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
785 const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
787 for (; j + 8 <= block_size; j += 8) {
791 vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
793 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
795 for (; j < block_size; j++) {
796 op[j] += wgt * ip[j];
799 if (normalize_by_lengths && lengths[rangeIndex]) {
800 float len_inv = 1.0f / lengths[rangeIndex];
801 __m256 vlen_inv = _mm256_set1_ps(len_inv);
803 for (; j + 8 <= block_size; j += 8) {
805 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
807 for (; j < block_size; j++) {
808 op[j] = len_inv * op[j];
813 return dataInd == index_size;
bool EmbeddingLookup_int64_t_float_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int64_t_float_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
bool EmbeddingLookup_int64_t_float_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int64_t_float_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
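// The *_half_* kernels below follow the same structure but read rows stored
// as at::Half, widening each group of 8 values to fp32 with _mm256_cvtph_ps
// before the FMA.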
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookup_int32_t_half_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int prefdist_T0 = 16;
  const int fused_block_size = block_size + 0;
  int dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
885 float* op = &out[rangeIndex * block_size];
886 __m256 vop0 = _mm256_setzero_ps();
887 __m256 vop8 = _mm256_setzero_ps();
888 __m256 vop16 = _mm256_setzero_ps();
889 __m256 vop24 = _mm256_setzero_ps();
890 __m256 vop32 = _mm256_setzero_ps();
891 __m256 vop40 = _mm256_setzero_ps();
892 __m256 vop48 = _mm256_setzero_ps();
893 __m256 vop56 = _mm256_setzero_ps();
894 __m256 vop64 = _mm256_setzero_ps();
895 __m256 vop72 = _mm256_setzero_ps();
896 __m256 vop80 = _mm256_setzero_ps();
897 __m256 vop88 = _mm256_setzero_ps();
898 __m256 vop96 = _mm256_setzero_ps();
899 __m256 vop104 = _mm256_setzero_ps();
900 __m256 vop112 = _mm256_setzero_ps();
901 __m256 vop120 = _mm256_setzero_ps();
902 if (dataInd + lengths[rangeIndex] > index_size) {
905 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
907 const int idx = indices[dataInd];
908 if (idx < 0 || idx >= data_size) {
913 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
915 __m256 vwgt = _mm256_set1_ps(wgt);
916 const at::Half* ip = &input[idx * fused_block_size];
917 const int next_T0 = (dataInd < index_size - prefdist_T0)
918 ? (dataInd + prefdist_T0)
920 const int idx_pref_T0 = indices[next_T0];
921 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
924 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
925 vop0 = _mm256_fmadd_ps(
928 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
931 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
932 vop8 = _mm256_fmadd_ps(
935 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
938 vop16 = _mm256_fmadd_ps(
941 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
944 vop24 = _mm256_fmadd_ps(
947 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
950 vop32 = _mm256_fmadd_ps(
953 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
956 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
957 vop40 = _mm256_fmadd_ps(
960 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
963 vop48 = _mm256_fmadd_ps(
966 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
969 vop56 = _mm256_fmadd_ps(
972 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
975 vop64 = _mm256_fmadd_ps(
978 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
981 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
982 vop72 = _mm256_fmadd_ps(
985 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
988 vop80 = _mm256_fmadd_ps(
991 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
994 vop88 = _mm256_fmadd_ps(
997 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1000 vop96 = _mm256_fmadd_ps(
1003 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1006 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1007 vop104 = _mm256_fmadd_ps(
1010 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1013 vop112 = _mm256_fmadd_ps(
1016 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1019 vop120 = _mm256_fmadd_ps(
1022 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1026 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1027 _mm256_storeu_ps(&op[0], vop0);
1028 _mm256_storeu_ps(&op[8], vop8);
1029 _mm256_storeu_ps(&op[16], vop16);
1030 _mm256_storeu_ps(&op[24], vop24);
1031 _mm256_storeu_ps(&op[32], vop32);
1032 _mm256_storeu_ps(&op[40], vop40);
1033 _mm256_storeu_ps(&op[48], vop48);
1034 _mm256_storeu_ps(&op[56], vop56);
1035 _mm256_storeu_ps(&op[64], vop64);
1036 _mm256_storeu_ps(&op[72], vop72);
1037 _mm256_storeu_ps(&op[80], vop80);
1038 _mm256_storeu_ps(&op[88], vop88);
1039 _mm256_storeu_ps(&op[96], vop96);
1040 _mm256_storeu_ps(&op[104], vop104);
1041 _mm256_storeu_ps(&op[112], vop112);
1042 _mm256_storeu_ps(&op[120], vop120);
1044 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1045 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1046 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1047 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1048 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1049 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1050 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1051 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1052 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1053 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1054 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1055 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1056 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1057 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1058 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1059 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1060 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1063 }
else if (block_size == 64) {
1065 for (
int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1066 float* op = &out[rangeIndex * block_size];
1067 __m256 vop0 = _mm256_setzero_ps();
1068 __m256 vop8 = _mm256_setzero_ps();
1069 __m256 vop16 = _mm256_setzero_ps();
1070 __m256 vop24 = _mm256_setzero_ps();
1071 __m256 vop32 = _mm256_setzero_ps();
1072 __m256 vop40 = _mm256_setzero_ps();
1073 __m256 vop48 = _mm256_setzero_ps();
1074 __m256 vop56 = _mm256_setzero_ps();
1075 if (dataInd + lengths[rangeIndex] > index_size) {
1078 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1080 const int idx = indices[dataInd];
1081 if (idx < 0 || idx >= data_size) {
1086 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1088 __m256 vwgt = _mm256_set1_ps(wgt);
1089 const at::Half* ip = &input[idx * fused_block_size];
1090 const int next_T0 = (dataInd < index_size - prefdist_T0)
1091 ? (dataInd + prefdist_T0)
1093 const int idx_pref_T0 = indices[next_T0];
1094 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1097 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1098 vop0 = _mm256_fmadd_ps(
1101 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1104 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1105 vop8 = _mm256_fmadd_ps(
1108 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1111 vop16 = _mm256_fmadd_ps(
1114 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1117 vop24 = _mm256_fmadd_ps(
1120 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1123 vop32 = _mm256_fmadd_ps(
1126 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1129 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1130 vop40 = _mm256_fmadd_ps(
1133 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1136 vop48 = _mm256_fmadd_ps(
1139 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1142 vop56 = _mm256_fmadd_ps(
1145 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1149 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1150 _mm256_storeu_ps(&op[0], vop0);
1151 _mm256_storeu_ps(&op[8], vop8);
1152 _mm256_storeu_ps(&op[16], vop16);
1153 _mm256_storeu_ps(&op[24], vop24);
1154 _mm256_storeu_ps(&op[32], vop32);
1155 _mm256_storeu_ps(&op[40], vop40);
1156 _mm256_storeu_ps(&op[48], vop48);
1157 _mm256_storeu_ps(&op[56], vop56);
1159 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1160 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1161 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1162 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1163 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1164 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1165 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1166 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1167 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1170 }
else if (block_size == 32) {
1172 for (
int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1173 float* op = &out[rangeIndex * block_size];
1174 __m256 vop0 = _mm256_setzero_ps();
1175 __m256 vop8 = _mm256_setzero_ps();
1176 __m256 vop16 = _mm256_setzero_ps();
1177 __m256 vop24 = _mm256_setzero_ps();
1178 if (dataInd + lengths[rangeIndex] > index_size) {
1181 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1183 const int idx = indices[dataInd];
1184 if (idx < 0 || idx >= data_size) {
1189 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1191 __m256 vwgt = _mm256_set1_ps(wgt);
1192 const at::Half* ip = &input[idx * fused_block_size];
1193 const int next_T0 = (dataInd < index_size - prefdist_T0)
1194 ? (dataInd + prefdist_T0)
1196 const int idx_pref_T0 = indices[next_T0];
1197 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1200 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1201 vop0 = _mm256_fmadd_ps(
1204 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1207 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1208 vop8 = _mm256_fmadd_ps(
1211 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1214 vop16 = _mm256_fmadd_ps(
1217 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1220 vop24 = _mm256_fmadd_ps(
1223 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1227 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1228 _mm256_storeu_ps(&op[0], vop0);
1229 _mm256_storeu_ps(&op[8], vop8);
1230 _mm256_storeu_ps(&op[16], vop16);
1231 _mm256_storeu_ps(&op[24], vop24);
1233 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1234 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1235 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1236 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1237 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1240 }
else if (block_size == 16) {
1242 for (
int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1243 float* op = &out[rangeIndex * block_size];
1244 __m256 vop0 = _mm256_setzero_ps();
1245 __m256 vop8 = _mm256_setzero_ps();
1246 if (dataInd + lengths[rangeIndex] > index_size) {
1249 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1251 const int idx = indices[dataInd];
1252 if (idx < 0 || idx >= data_size) {
1257 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1259 __m256 vwgt = _mm256_set1_ps(wgt);
1260 const at::Half* ip = &input[idx * fused_block_size];
1261 const int next_T0 = (dataInd < index_size - prefdist_T0)
1262 ? (dataInd + prefdist_T0)
1264 const int idx_pref_T0 = indices[next_T0];
1265 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1268 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1269 vop0 = _mm256_fmadd_ps(
1272 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1275 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1276 vop8 = _mm256_fmadd_ps(
1279 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1283 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1284 _mm256_storeu_ps(&op[0], vop0);
1285 _mm256_storeu_ps(&op[8], vop8);
1287 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1288 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1289 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1294 for (
int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1295 float* op = &out[rangeIndex * block_size];
1297 for (; j + 8 <= block_size; j += 8) {
1298 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1300 for (; j < block_size; j++) {
1303 if (dataInd + lengths[rangeIndex] > index_size) {
1306 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1308 const int idx = indices[dataInd];
1309 if (idx < 0 || idx >= data_size) {
1314 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1316 __m256 vwgt = _mm256_set1_ps(wgt);
1317 const at::Half* ip = &input[idx * fused_block_size];
1318 const int next_T0 = (dataInd < index_size - prefdist_T0)
1319 ? (dataInd + prefdist_T0)
1321 const int idx_pref_T0 = indices[next_T0];
1322 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1325 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1327 for (; j + 8 <= block_size; j += 8) {
1332 _mm256_cvtph_ps(_mm_loadu_si128(
1333 reinterpret_cast<const __m128i*>(&ip[j]))),
1334 _mm256_loadu_ps(&op[j])));
1336 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1339 for (; j < block_size; j++) {
1341 __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1342 op[j] += wgt * ((
float*)(&vtmp2))[0];
1345 if (normalize_by_lengths && lengths[rangeIndex]) {
1346 float len_inv = 1.0f / lengths[rangeIndex];
1347 __m256 vlen_inv = _mm256_set1_ps(len_inv);
1349 for (; j + 8 <= block_size; j += 8) {
1351 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1353 for (; j < block_size; j++) {
1354 op[j] = len_inv * op[j];
1359 return dataInd == index_size;
bool EmbeddingLookup_int32_t_half_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int32_t_half_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
bool EmbeddingLookup_int32_t_half_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int32_t_half_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookup_int64_t_half_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 0;
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1431 float* op = &out[rangeIndex * block_size];
1432 __m256 vop0 = _mm256_setzero_ps();
1433 __m256 vop8 = _mm256_setzero_ps();
1434 __m256 vop16 = _mm256_setzero_ps();
1435 __m256 vop24 = _mm256_setzero_ps();
1436 __m256 vop32 = _mm256_setzero_ps();
1437 __m256 vop40 = _mm256_setzero_ps();
1438 __m256 vop48 = _mm256_setzero_ps();
1439 __m256 vop56 = _mm256_setzero_ps();
1440 __m256 vop64 = _mm256_setzero_ps();
1441 __m256 vop72 = _mm256_setzero_ps();
1442 __m256 vop80 = _mm256_setzero_ps();
1443 __m256 vop88 = _mm256_setzero_ps();
1444 __m256 vop96 = _mm256_setzero_ps();
1445 __m256 vop104 = _mm256_setzero_ps();
1446 __m256 vop112 = _mm256_setzero_ps();
1447 __m256 vop120 = _mm256_setzero_ps();
1448 if (dataInd + lengths[rangeIndex] > index_size) {
1451 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1453 const int64_t idx = indices[dataInd];
1454 if (idx < 0 || idx >= data_size) {
1459 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1461 __m256 vwgt = _mm256_set1_ps(wgt);
1462 const at::Half* ip = &input[idx * fused_block_size];
1463 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1464 ? (dataInd + prefdist_T0)
1466 const int64_t idx_pref_T0 = indices[next_T0];
1467 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1470 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1471 vop0 = _mm256_fmadd_ps(
1474 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1477 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1478 vop8 = _mm256_fmadd_ps(
1481 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1484 vop16 = _mm256_fmadd_ps(
1487 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1490 vop24 = _mm256_fmadd_ps(
1493 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1496 vop32 = _mm256_fmadd_ps(
1499 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1502 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1503 vop40 = _mm256_fmadd_ps(
1506 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1509 vop48 = _mm256_fmadd_ps(
1512 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1515 vop56 = _mm256_fmadd_ps(
1518 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1521 vop64 = _mm256_fmadd_ps(
1524 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1527 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1528 vop72 = _mm256_fmadd_ps(
1531 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1534 vop80 = _mm256_fmadd_ps(
1537 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1540 vop88 = _mm256_fmadd_ps(
1543 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1546 vop96 = _mm256_fmadd_ps(
1549 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1552 reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1553 vop104 = _mm256_fmadd_ps(
1556 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1559 vop112 = _mm256_fmadd_ps(
1562 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1565 vop120 = _mm256_fmadd_ps(
1568 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1572 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1573 _mm256_storeu_ps(&op[0], vop0);
1574 _mm256_storeu_ps(&op[8], vop8);
1575 _mm256_storeu_ps(&op[16], vop16);
1576 _mm256_storeu_ps(&op[24], vop24);
1577 _mm256_storeu_ps(&op[32], vop32);
1578 _mm256_storeu_ps(&op[40], vop40);
1579 _mm256_storeu_ps(&op[48], vop48);
1580 _mm256_storeu_ps(&op[56], vop56);
1581 _mm256_storeu_ps(&op[64], vop64);
1582 _mm256_storeu_ps(&op[72], vop72);
1583 _mm256_storeu_ps(&op[80], vop80);
1584 _mm256_storeu_ps(&op[88], vop88);
1585 _mm256_storeu_ps(&op[96], vop96);
1586 _mm256_storeu_ps(&op[104], vop104);
1587 _mm256_storeu_ps(&op[112], vop112);
1588 _mm256_storeu_ps(&op[120], vop120);
1590 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1591 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1592 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1593 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1594 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1595 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1596 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1597 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1598 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1599 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1600 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1601 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1602 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1603 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1604 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1605 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1606 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1609 }
else if (block_size == 64) {
1611 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1612 float* op = &out[rangeIndex * block_size];
1613 __m256 vop0 = _mm256_setzero_ps();
1614 __m256 vop8 = _mm256_setzero_ps();
1615 __m256 vop16 = _mm256_setzero_ps();
1616 __m256 vop24 = _mm256_setzero_ps();
1617 __m256 vop32 = _mm256_setzero_ps();
1618 __m256 vop40 = _mm256_setzero_ps();
1619 __m256 vop48 = _mm256_setzero_ps();
1620 __m256 vop56 = _mm256_setzero_ps();
1621 if (dataInd + lengths[rangeIndex] > index_size) {
1624 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1626 const int64_t idx = indices[dataInd];
1627 if (idx < 0 || idx >= data_size) {
1632 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1634 __m256 vwgt = _mm256_set1_ps(wgt);
1635 const at::Half* ip = &input[idx * fused_block_size];
1636 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1637 ? (dataInd + prefdist_T0)
1639 const int64_t idx_pref_T0 = indices[next_T0];
1640 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1643 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1644 vop0 = _mm256_fmadd_ps(
1647 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1650 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1651 vop8 = _mm256_fmadd_ps(
1654 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1657 vop16 = _mm256_fmadd_ps(
1660 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1663 vop24 = _mm256_fmadd_ps(
1666 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1669 vop32 = _mm256_fmadd_ps(
1672 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1675 reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1676 vop40 = _mm256_fmadd_ps(
1679 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1682 vop48 = _mm256_fmadd_ps(
1685 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1688 vop56 = _mm256_fmadd_ps(
1691 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1695 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1696 _mm256_storeu_ps(&op[0], vop0);
1697 _mm256_storeu_ps(&op[8], vop8);
1698 _mm256_storeu_ps(&op[16], vop16);
1699 _mm256_storeu_ps(&op[24], vop24);
1700 _mm256_storeu_ps(&op[32], vop32);
1701 _mm256_storeu_ps(&op[40], vop40);
1702 _mm256_storeu_ps(&op[48], vop48);
1703 _mm256_storeu_ps(&op[56], vop56);
1705 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1706 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1707 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1708 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1709 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1710 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1711 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1712 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1713 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1716 }
else if (block_size == 32) {
1718 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1719 float* op = &out[rangeIndex * block_size];
1720 __m256 vop0 = _mm256_setzero_ps();
1721 __m256 vop8 = _mm256_setzero_ps();
1722 __m256 vop16 = _mm256_setzero_ps();
1723 __m256 vop24 = _mm256_setzero_ps();
1724 if (dataInd + lengths[rangeIndex] > index_size) {
1727 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1729 const int64_t idx = indices[dataInd];
1730 if (idx < 0 || idx >= data_size) {
1735 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1737 __m256 vwgt = _mm256_set1_ps(wgt);
1738 const at::Half* ip = &input[idx * fused_block_size];
1739 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1740 ? (dataInd + prefdist_T0)
1742 const int64_t idx_pref_T0 = indices[next_T0];
1743 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1746 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1747 vop0 = _mm256_fmadd_ps(
1750 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1753 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1754 vop8 = _mm256_fmadd_ps(
1757 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1760 vop16 = _mm256_fmadd_ps(
1763 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1766 vop24 = _mm256_fmadd_ps(
1769 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1773 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1774 _mm256_storeu_ps(&op[0], vop0);
1775 _mm256_storeu_ps(&op[8], vop8);
1776 _mm256_storeu_ps(&op[16], vop16);
1777 _mm256_storeu_ps(&op[24], vop24);
1779 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1780 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1781 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1782 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1783 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1786 }
else if (block_size == 16) {
1788 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1789 float* op = &out[rangeIndex * block_size];
1790 __m256 vop0 = _mm256_setzero_ps();
1791 __m256 vop8 = _mm256_setzero_ps();
1792 if (dataInd + lengths[rangeIndex] > index_size) {
1795 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1797 const int64_t idx = indices[dataInd];
1798 if (idx < 0 || idx >= data_size) {
1803 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1805 __m256 vwgt = _mm256_set1_ps(wgt);
1806 const at::Half* ip = &input[idx * fused_block_size];
1807 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1808 ? (dataInd + prefdist_T0)
1810 const int64_t idx_pref_T0 = indices[next_T0];
1811 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1814 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1815 vop0 = _mm256_fmadd_ps(
1818 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1821 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1822 vop8 = _mm256_fmadd_ps(
1825 _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1829 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1830 _mm256_storeu_ps(&op[0], vop0);
1831 _mm256_storeu_ps(&op[8], vop8);
1833 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1834 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1835 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1840 for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1841 float* op = &out[rangeIndex * block_size];
1843 for (; j + 8 <= block_size; j += 8) {
1844 _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1846 for (; j < block_size; j++) {
1849 if (dataInd + lengths[rangeIndex] > index_size) {
1852 for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1854 const int64_t idx = indices[dataInd];
1855 if (idx < 0 || idx >= data_size) {
1860 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1862 __m256 vwgt = _mm256_set1_ps(wgt);
1863 const at::Half* ip = &input[idx * fused_block_size];
1864 const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1865 ? (dataInd + prefdist_T0)
1867 const int64_t idx_pref_T0 = indices[next_T0];
1868 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1871 const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1873 for (; j + 8 <= block_size; j += 8) {
1878 _mm256_cvtph_ps(_mm_loadu_si128(
1879 reinterpret_cast<const __m128i*>(&ip[j]))),
1880 _mm256_loadu_ps(&op[j])));
1882 reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1885 for (; j < block_size; j++) {
1887 __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1888 op[j] += wgt * ((
float*)(&vtmp2))[0];
1891 if (normalize_by_lengths && lengths[rangeIndex]) {
1892 float len_inv = 1.0f / lengths[rangeIndex];
1893 __m256 vlen_inv = _mm256_set1_ps(len_inv);
1895 for (; j + 8 <= block_size; j += 8) {
1897 &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1899 for (; j < block_size; j++) {
1900 op[j] = len_inv * op[j];
1905 return dataInd == index_size;
bool EmbeddingLookup_int64_t_half_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int64_t_half_float__avx2_fma<false>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
bool EmbeddingLookup_int64_t_half_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int64_t_half_float__avx2_fma<true>(
      block_size, output_size, index_size, data_size, input, indices,
      lengths, weights, scale_bias, normalize_by_lengths, out);
}
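// The *_uint8_t_* kernels below read 8-bit quantized rows: the weight is
// folded with the per-row scale (scale_bias[2 * idx]) and the bias term
// wgt * scale_bias[2 * idx + 1] is added to the accumulator before the FMA,
// so each looked-up value contributes wgt * (scale * q + bias).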
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int prefdist_T0 = 16;
  const int fused_block_size = block_size + 0;
  int dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1977 float* op = &out[rangeIndex * block_size];
1978 __m256 vop0 = _mm256_setzero_ps();
1979 __m256 vop8 = _mm256_setzero_ps();
1980 __m256 vop16 = _mm256_setzero_ps();
1981 __m256 vop24 = _mm256_setzero_ps();
1982 __m256 vop32 = _mm256_setzero_ps();
1983 __m256 vop40 = _mm256_setzero_ps();
1984 __m256 vop48 = _mm256_setzero_ps();
1985 __m256 vop56 = _mm256_setzero_ps();
1986 __m256 vop64 = _mm256_setzero_ps();
1987 __m256 vop72 = _mm256_setzero_ps();
1988 __m256 vop80 = _mm256_setzero_ps();
1989 __m256 vop88 = _mm256_setzero_ps();
1990 __m256 vop96 = _mm256_setzero_ps();
1991 __m256 vop104 = _mm256_setzero_ps();
1992 __m256 vop112 = _mm256_setzero_ps();
1993 __m256 vop120 = _mm256_setzero_ps();
1994 if (dataInd + lengths[rangeIndex] > index_size) {
1997 for (
int start = dataInd; dataInd < start + lengths[rangeIndex];
1999 const int idx = indices[dataInd];
2000 if (idx < 0 || idx >= data_size) {
2006 wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2008 bio = wgt * scale_bias[2 * idx + 1];
2009 wgt = wgt * scale_bias[2 * idx];
2010 __m256 vbio = _mm256_set1_ps(bio);
2011 __m256 vwgt = _mm256_set1_ps(wgt);
2012 const uint8_t* ip = &input[idx * fused_block_size];
2013 const int next_T0 = (dataInd < index_size - prefdist_T0)
2014 ? (dataInd + prefdist_T0)
2016 const int idx_pref_T0 = indices[next_T0];
2017 if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2020 const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2021 vop0 = _mm256_fmadd_ps(
2023 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2024 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2025 _mm256_add_ps(vop0, vbio));
2027 reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2028 vop8 = _mm256_fmadd_ps(
2030 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2031 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2032 _mm256_add_ps(vop8, vbio));
2034 vop16 = _mm256_fmadd_ps(
2036 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2037 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2038 _mm256_add_ps(vop16, vbio));
2040 vop24 = _mm256_fmadd_ps(
2042 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2043 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2044 _mm256_add_ps(vop24, vbio));
2046 vop32 = _mm256_fmadd_ps(
2048 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2049 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2050 _mm256_add_ps(vop32, vbio));
2052 vop40 = _mm256_fmadd_ps(
2054 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2055 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2056 _mm256_add_ps(vop40, vbio));
2058 vop48 = _mm256_fmadd_ps(
2060 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2061 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2062 _mm256_add_ps(vop48, vbio));
2064 vop56 = _mm256_fmadd_ps(
2066 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2067 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2068 _mm256_add_ps(vop56, vbio));
2070 vop64 = _mm256_fmadd_ps(
2072 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2073 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2074 _mm256_add_ps(vop64, vbio));
2076 reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2077 vop72 = _mm256_fmadd_ps(
2079 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2080 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2081 _mm256_add_ps(vop72, vbio));
2083 vop80 = _mm256_fmadd_ps(
2085 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2086 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2087 _mm256_add_ps(vop80, vbio));
2089 vop88 = _mm256_fmadd_ps(
2091 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2092 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2093 _mm256_add_ps(vop88, vbio));
2095 vop96 = _mm256_fmadd_ps(
2097 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2098 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2099 _mm256_add_ps(vop96, vbio));
2101 vop104 = _mm256_fmadd_ps(
2103 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2104 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2105 _mm256_add_ps(vop104, vbio));
2107 vop112 = _mm256_fmadd_ps(
2109 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2110 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2111 _mm256_add_ps(vop112, vbio));
2113 vop120 = _mm256_fmadd_ps(
2115 _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2116 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2117 _mm256_add_ps(vop120, vbio));
2120 if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2121 _mm256_storeu_ps(&op[0], vop0);
2122 _mm256_storeu_ps(&op[8], vop8);
2123 _mm256_storeu_ps(&op[16], vop16);
2124 _mm256_storeu_ps(&op[24], vop24);
2125 _mm256_storeu_ps(&op[32], vop32);
2126 _mm256_storeu_ps(&op[40], vop40);
2127 _mm256_storeu_ps(&op[48], vop48);
2128 _mm256_storeu_ps(&op[56], vop56);
2129 _mm256_storeu_ps(&op[64], vop64);
2130 _mm256_storeu_ps(&op[72], vop72);
2131 _mm256_storeu_ps(&op[80], vop80);
2132 _mm256_storeu_ps(&op[88], vop88);
2133 _mm256_storeu_ps(&op[96], vop96);
2134 _mm256_storeu_ps(&op[104], vop104);
2135 _mm256_storeu_ps(&op[112], vop112);
2136 _mm256_storeu_ps(&op[120], vop120);
2138 __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2139 _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2140 _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2141 _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2142 _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2143 _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2144 _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2145 _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2146 _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2147 _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2148 _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2149 _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2150 _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2151 _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2152 _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2153 _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2154 _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2157 }
else if (block_size == 64) {
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[32])
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
bool EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

bool EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}
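
// Illustrative note (not part of the generated kernel): the uint8_t kernels
// above dequantize with a per-row scale s = scale_bias[2 * idx] and bias
// b = scale_bias[2 * idx + 1]. For one element the accumulation is
//
//   op[k] += wgt * (s * ip[k] + b)
//          = (wgt * s) * ip[k] + (wgt * b);
//
// which is why the code folds `wgt` into both the scale (reusing `wgt`) and
// the bias (`bio`), seeds the accumulator with `vop + vbio`, and then needs
// only a single FMA per 8 elements.
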
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 0;
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[32])
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[56])
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
            _mm256_add_ps(vop64, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
            _mm256_add_ps(vop72, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[72])
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
            _mm256_add_ps(vop80, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[80])
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
            _mm256_add_ps(vop88, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[88])
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
            _mm256_add_ps(vop96, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[96])
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
            _mm256_add_ps(vop104, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[104])
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
            _mm256_add_ps(vop112, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[112])
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
            _mm256_add_ps(vop120, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[120])
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[32])
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd + lengths[rangeIndex] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        float bio;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        __m256 vbio = _mm256_set1_ps(bio);
        __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] += wgt * ((float)ip[j]) + bio;
        }
      }
      if (normalize_by_lengths && lengths[rangeIndex]) {
        float len_inv = 1.0f / lengths[rangeIndex];
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
bool EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}

bool EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
      block_size,
      output_size,
      index_size,
      data_size,
      input,
      indices,
      lengths,
      weights,
      scale_bias,
      normalize_by_lengths,
      out);
}
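
// Example (illustrative sketch, not generated code): the uint8_t wrappers
// above expect `scale_bias` to hold two floats per table row, laid out as
// {scale_0, bias_0, scale_1, bias_1, ...}. A minimal hypothetical call with a
// single segment of two indices and mean-pooling:
//
//   const int64_t block_size = 32, data_size = 16;
//   std::vector<uint8_t> table(data_size * block_size);
//   std::vector<float> scale_bias(2 * data_size, 1.0f);
//   std::vector<int64_t> indices = {3, 7};
//   std::vector<int> lengths = {2};
//   std::vector<float> out(block_size);
//   bool ok = EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
//       block_size, /*output_size=*/1, /*index_size=*/2, data_size,
//       table.data(), indices.data(), lengths.data(),
//       /*weights=*/nullptr, scale_bias.data(),
//       /*normalize_by_lengths=*/true, out.data());
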
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...