Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding_lookup_avx2.cc
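This listing is the AVX2+FMA specialization of Caffe2's EmbeddingLookup perfkernels: for every output segment described by `lengths`, the selected embedding rows are gathered from `input` via `indices`, optionally scaled by `weights`, accumulated, and optionally averaged over the segment length. Inner loops are fully unrolled for block sizes 128, 64, 32, and 16, with a generic fallback for other sizes. The sketch below shows how one of the exported wrappers might be driven; the calling code, the `example()` function, and all data in it are illustrative assumptions, not part of this file.

#include <cstdint>
#include <vector>

namespace caffe2 {
// Redeclared here only so the sketch stands alone; the real declaration is
// provided by the Caffe2 perfkernels sources.
bool EmbeddingLookup_int32_t_float_float_false__avx2_fma(
    const int64_t block_size, const int64_t output_size,
    const int64_t index_size, const int64_t data_size,
    const float* input, const int* indices, const int* lengths,
    const float* weights, const float* scale_bias,
    bool normalize_by_lengths, float* out);
} // namespace caffe2

void example() {
  const int64_t block_size = 16, output_size = 2, index_size = 3, data_size = 4;
  std::vector<float> table(data_size * block_size, 1.0f); // 4 rows of 16 floats
  std::vector<int> indices = {0, 2, 3}; // rows gathered from `table`
  std::vector<int> lengths = {2, 1};    // out[0] sums 2 rows, out[1] sums 1 row
  std::vector<float> out(output_size * block_size, 0.0f);
  // Unweighted, no scale/bias, each output block averaged over its length.
  bool ok = caffe2::EmbeddingLookup_int32_t_float_float_false__avx2_fma(
      block_size, output_size, index_size, data_size,
      table.data(), indices.data(), lengths.data(),
      /*weights=*/nullptr, /*scale_bias=*/nullptr,
      /*normalize_by_lengths=*/true, out.data());
  (void)ok; // false would signal an out-of-range index or a length overrun
}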
8 #include <c10/util/Half.h>
9 #include <immintrin.h>
10 namespace caffe2 {
11 
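// Common structure of the kernels below: for each output segment, accumulate
// wgt * input[idx * block_size + j] into the corresponding output block, where
// idx runs over the segment's entries in `indices` and wgt comes from
// `weights` (positional offset or global entry index, 1.0f when weights is
// null). The row prefdist_T0 lookups ahead is prefetched with the T0 hint.
// If normalize_by_lengths is set and the segment is non-empty, the accumulated
// block is divided by the segment length. The kernels return false on an
// out-of-range index or when a segment would read past index_size, and true
// only if exactly index_size indices were consumed.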
12 template <bool IS_WEIGHT_POSITIONAL>
13 static bool EmbeddingLookup_int32_t_float_float__avx2_fma(
14  const int64_t block_size,
15  const int64_t output_size,
16  const int64_t index_size,
17  const int64_t data_size,
18  const float* input,
19  const int* indices,
20  const int* lengths,
21  const float* weights,
22  const float* scale_bias,
23  bool normalize_by_lengths,
24  float* out) {
25  const int prefdist_T0 = 16;
26  const int fused_block_size = block_size + 0;
27  int dataInd = 0;
28  if (block_size == 128) {
29  // unrolling 16 times
30  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
31  float* op = &out[rangeIndex * block_size];
32  __m256 vop0 = _mm256_setzero_ps();
33  __m256 vop8 = _mm256_setzero_ps();
34  __m256 vop16 = _mm256_setzero_ps();
35  __m256 vop24 = _mm256_setzero_ps();
36  __m256 vop32 = _mm256_setzero_ps();
37  __m256 vop40 = _mm256_setzero_ps();
38  __m256 vop48 = _mm256_setzero_ps();
39  __m256 vop56 = _mm256_setzero_ps();
40  __m256 vop64 = _mm256_setzero_ps();
41  __m256 vop72 = _mm256_setzero_ps();
42  __m256 vop80 = _mm256_setzero_ps();
43  __m256 vop88 = _mm256_setzero_ps();
44  __m256 vop96 = _mm256_setzero_ps();
45  __m256 vop104 = _mm256_setzero_ps();
46  __m256 vop112 = _mm256_setzero_ps();
47  __m256 vop120 = _mm256_setzero_ps();
48  if (dataInd + lengths[rangeIndex] > index_size) {
49  return false;
50  }
51  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
52  ++dataInd) {
53  const int idx = indices[dataInd];
54  if (idx < 0 || idx >= data_size) {
55  return false;
56  }
57  float wgt = 1.f;
58  if (weights) {
59  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
60  }
61  __m256 vwgt = _mm256_set1_ps(wgt);
62  const float* ip = &input[idx * fused_block_size];
63  const int next_T0 = (dataInd < index_size - prefdist_T0)
64  ? (dataInd + prefdist_T0)
65  : dataInd;
66  const int idx_pref_T0 = indices[next_T0];
67  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
68  return false;
69  }
70  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
71  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
72  _mm_prefetch(
73  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
74  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
75  // skip unnecessary prefetch of (&ip_next_T0[8])
76  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
77  _mm_prefetch(
78  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
79  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
80  // skip unnecessary prefetch of (&ip_next_T0[24])
81  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
82  _mm_prefetch(
83  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
84  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
85  // skip unnecessary prefetch of (&ip_next_T0[40])
86  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
87  _mm_prefetch(
88  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
89  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
90  // skip unnecessary prefetch of (&ip_next_T0[56])
91  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
92  _mm_prefetch(
93  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
94  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
95  // skip unnecessary prefetch of (&ip_next_T0[72])
96  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
97  _mm_prefetch(
98  reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
99  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
100  // skip unnecessary prefetch of (&ip_next_T0[88])
101  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
102  _mm_prefetch(
103  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
104  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
105  // skip unnecessary prefetch of (&ip_next_T0[104])
106  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
107  _mm_prefetch(
108  reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
109  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
110  // skip unnecessary prefetch of (&ip_next_T0[120])
111  }
112  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
113  _mm256_storeu_ps(&op[0], vop0);
114  _mm256_storeu_ps(&op[8], vop8);
115  _mm256_storeu_ps(&op[16], vop16);
116  _mm256_storeu_ps(&op[24], vop24);
117  _mm256_storeu_ps(&op[32], vop32);
118  _mm256_storeu_ps(&op[40], vop40);
119  _mm256_storeu_ps(&op[48], vop48);
120  _mm256_storeu_ps(&op[56], vop56);
121  _mm256_storeu_ps(&op[64], vop64);
122  _mm256_storeu_ps(&op[72], vop72);
123  _mm256_storeu_ps(&op[80], vop80);
124  _mm256_storeu_ps(&op[88], vop88);
125  _mm256_storeu_ps(&op[96], vop96);
126  _mm256_storeu_ps(&op[104], vop104);
127  _mm256_storeu_ps(&op[112], vop112);
128  _mm256_storeu_ps(&op[120], vop120);
129  } else {
130  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
131  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
132  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
133  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
134  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
135  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
136  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
137  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
138  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
139  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
140  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
141  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
142  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
143  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
144  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
145  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
146  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
147  }
148  }
149  } else if (block_size == 64) {
150  // unrolling 8 times
151  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
152  float* op = &out[rangeIndex * block_size];
153  __m256 vop0 = _mm256_setzero_ps();
154  __m256 vop8 = _mm256_setzero_ps();
155  __m256 vop16 = _mm256_setzero_ps();
156  __m256 vop24 = _mm256_setzero_ps();
157  __m256 vop32 = _mm256_setzero_ps();
158  __m256 vop40 = _mm256_setzero_ps();
159  __m256 vop48 = _mm256_setzero_ps();
160  __m256 vop56 = _mm256_setzero_ps();
161  if (dataInd + lengths[rangeIndex] > index_size) {
162  return false;
163  }
164  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
165  ++dataInd) {
166  const int idx = indices[dataInd];
167  if (idx < 0 || idx >= data_size) {
168  return false;
169  }
170  float wgt = 1.f;
171  if (weights) {
172  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
173  }
174  __m256 vwgt = _mm256_set1_ps(wgt);
175  const float* ip = &input[idx * fused_block_size];
176  const int next_T0 = (dataInd < index_size - prefdist_T0)
177  ? (dataInd + prefdist_T0)
178  : dataInd;
179  const int idx_pref_T0 = indices[next_T0];
180  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
181  return false;
182  }
183  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
184  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
185  _mm_prefetch(
186  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
187  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
188  // skip unnecessary prefetch of (&ip_next_T0[8])
189  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
190  _mm_prefetch(
191  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
192  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
193  // skip unnecessary prefetch of (&ip_next_T0[24])
194  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
195  _mm_prefetch(
196  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
197  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
198  // skip unnecessary prefetch of (&ip_next_T0[40])
199  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
200  _mm_prefetch(
201  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
202  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
203  // skip unnecessary prefetch of (&ip_next_T0[56])
204  }
205  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
206  _mm256_storeu_ps(&op[0], vop0);
207  _mm256_storeu_ps(&op[8], vop8);
208  _mm256_storeu_ps(&op[16], vop16);
209  _mm256_storeu_ps(&op[24], vop24);
210  _mm256_storeu_ps(&op[32], vop32);
211  _mm256_storeu_ps(&op[40], vop40);
212  _mm256_storeu_ps(&op[48], vop48);
213  _mm256_storeu_ps(&op[56], vop56);
214  } else {
215  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
216  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
217  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
218  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
219  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
220  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
221  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
222  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
223  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
224  }
225  }
226  } else if (block_size == 32) {
227  // unrolling 4 times
228  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
229  float* op = &out[rangeIndex * block_size];
230  __m256 vop0 = _mm256_setzero_ps();
231  __m256 vop8 = _mm256_setzero_ps();
232  __m256 vop16 = _mm256_setzero_ps();
233  __m256 vop24 = _mm256_setzero_ps();
234  if (dataInd + lengths[rangeIndex] > index_size) {
235  return false;
236  }
237  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
238  ++dataInd) {
239  const int idx = indices[dataInd];
240  if (idx < 0 || idx >= data_size) {
241  return false;
242  }
243  float wgt = 1.f;
244  if (weights) {
245  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
246  }
247  __m256 vwgt = _mm256_set1_ps(wgt);
248  const float* ip = &input[idx * fused_block_size];
249  const int next_T0 = (dataInd < index_size - prefdist_T0)
250  ? (dataInd + prefdist_T0)
251  : dataInd;
252  const int idx_pref_T0 = indices[next_T0];
253  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
254  return false;
255  }
256  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
257  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
258  _mm_prefetch(
259  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
260  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
261  // skip unnecessary prefetch of (&ip_next_T0[8])
262  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
263  _mm_prefetch(
264  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
265  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
266  // skip unnecessary prefetch of (&ip_next_T0[24])
267  }
268  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
269  _mm256_storeu_ps(&op[0], vop0);
270  _mm256_storeu_ps(&op[8], vop8);
271  _mm256_storeu_ps(&op[16], vop16);
272  _mm256_storeu_ps(&op[24], vop24);
273  } else {
274  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
275  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
276  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
277  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
278  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
279  }
280  }
281  } else if (block_size == 16) {
282  // unrolling 2 times
283  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
284  float* op = &out[rangeIndex * block_size];
285  __m256 vop0 = _mm256_setzero_ps();
286  __m256 vop8 = _mm256_setzero_ps();
287  if (dataInd + lengths[rangeIndex] > index_size) {
288  return false;
289  }
290  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
291  ++dataInd) {
292  const int idx = indices[dataInd];
293  if (idx < 0 || idx >= data_size) {
294  return false;
295  }
296  float wgt = 1.f;
297  if (weights) {
298  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
299  }
300  __m256 vwgt = _mm256_set1_ps(wgt);
301  const float* ip = &input[idx * fused_block_size];
302  const int next_T0 = (dataInd < index_size - prefdist_T0)
303  ? (dataInd + prefdist_T0)
304  : dataInd;
305  const int idx_pref_T0 = indices[next_T0];
306  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
307  return false;
308  }
309  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
310  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
311  _mm_prefetch(
312  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
313  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
314  // skip unnecessary prefetch of (&ip_next_T0[8])
315  }
316  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
317  _mm256_storeu_ps(&op[0], vop0);
318  _mm256_storeu_ps(&op[8], vop8);
319  } else {
320  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
321  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
322  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
323  }
324  }
325  } else {
326  // generic code
327  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
328  float* op = &out[rangeIndex * block_size];
329  int64_t j = 0;
330  for (; j + 8 <= block_size; j += 8) {
331  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
332  }
333  for (; j < block_size; j++) {
334  op[j] = 0.0f;
335  }
336  if (dataInd + lengths[rangeIndex] > index_size) {
337  return false;
338  }
339  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
340  ++dataInd) {
341  const int idx = indices[dataInd];
342  if (idx < 0 || idx >= data_size) {
343  return false;
344  }
345  float wgt = 1.f;
346  if (weights) {
347  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
348  }
349  __m256 vwgt = _mm256_set1_ps(wgt);
350  const float* ip = &input[idx * fused_block_size];
351  const int next_T0 = (dataInd < index_size - prefdist_T0)
352  ? (dataInd + prefdist_T0)
353  : dataInd;
354  const int idx_pref_T0 = indices[next_T0];
355  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
356  return false;
357  }
358  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
359  j = 0;
360  for (; j + 8 <= block_size; j += 8) {
361  _mm256_storeu_ps(
362  &op[j],
363  _mm256_fmadd_ps(
364  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
365  _mm_prefetch(
366  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
367  }
368  for (; j < block_size; j++) {
369  op[j] += wgt * ip[j];
370  }
371  }
372  if (normalize_by_lengths && lengths[rangeIndex]) {
373  float len_inv = 1.0f / lengths[rangeIndex];
374  __m256 vlen_inv = _mm256_set1_ps(len_inv);
375  j = 0;
376  for (; j + 8 <= block_size; j += 8) {
377  _mm256_storeu_ps(
378  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
379  }
380  for (; j < block_size; j++) {
381  op[j] = len_inv * op[j];
382  }
383  }
384  }
385  }
386  return dataInd == index_size;
387 }
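// Exported entry points: thin wrappers that instantiate the template above
// with IS_WEIGHT_POSITIONAL fixed to false or true.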
388 bool EmbeddingLookup_int32_t_float_float_false__avx2_fma(
389  const int64_t block_size,
390  const int64_t output_size,
391  const int64_t index_size,
392  const int64_t data_size,
393  const float* input,
394  const int* indices,
395  const int* lengths,
396  const float* weights,
397  const float* scale_bias,
398  bool normalize_by_lengths,
399  float* out) {
400  return EmbeddingLookup_int32_t_float_float__avx2_fma<false>(
401  block_size,
402  output_size,
403  index_size,
404  data_size,
405  input,
406  indices,
407  lengths,
408  weights,
409  scale_bias,
410  normalize_by_lengths,
411  out);
412 }
413 bool EmbeddingLookup_int32_t_float_float_true__avx2_fma(
414  const int64_t block_size,
415  const int64_t output_size,
416  const int64_t index_size,
417  const int64_t data_size,
418  const float* input,
419  const int* indices,
420  const int* lengths,
421  const float* weights,
422  const float* scale_bias,
423  bool normalize_by_lengths,
424  float* out) {
425  return EmbeddingLookup_int32_t_float_float__avx2_fma<true>(
426  block_size,
427  output_size,
428  index_size,
429  data_size,
430  input,
431  indices,
432  lengths,
433  weights,
434  scale_bias,
435  normalize_by_lengths,
436  out);
437 }
438 
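// The int64_t-index variant below mirrors the int32_t version above; the only
// difference is that `indices` and the loop counters are 64-bit.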
439 template <bool IS_WEIGHT_POSITIONAL>
440 static bool EmbeddingLookup_int64_t_float_float__avx2_fma(
441  const int64_t block_size,
442  const int64_t output_size,
443  const int64_t index_size,
444  const int64_t data_size,
445  const float* input,
446  const int64_t* indices,
447  const int* lengths,
448  const float* weights,
449  const float* scale_bias,
450  bool normalize_by_lengths,
451  float* out) {
452  const int64_t prefdist_T0 = 16;
453  const int64_t fused_block_size = block_size + 0;
454  int64_t dataInd = 0;
455  if (block_size == 128) {
456  // unrolling 16 times
457  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
458  float* op = &out[rangeIndex * block_size];
459  __m256 vop0 = _mm256_setzero_ps();
460  __m256 vop8 = _mm256_setzero_ps();
461  __m256 vop16 = _mm256_setzero_ps();
462  __m256 vop24 = _mm256_setzero_ps();
463  __m256 vop32 = _mm256_setzero_ps();
464  __m256 vop40 = _mm256_setzero_ps();
465  __m256 vop48 = _mm256_setzero_ps();
466  __m256 vop56 = _mm256_setzero_ps();
467  __m256 vop64 = _mm256_setzero_ps();
468  __m256 vop72 = _mm256_setzero_ps();
469  __m256 vop80 = _mm256_setzero_ps();
470  __m256 vop88 = _mm256_setzero_ps();
471  __m256 vop96 = _mm256_setzero_ps();
472  __m256 vop104 = _mm256_setzero_ps();
473  __m256 vop112 = _mm256_setzero_ps();
474  __m256 vop120 = _mm256_setzero_ps();
475  if (dataInd + lengths[rangeIndex] > index_size) {
476  return false;
477  }
478  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
479  ++dataInd) {
480  const int64_t idx = indices[dataInd];
481  if (idx < 0 || idx >= data_size) {
482  return false;
483  }
484  float wgt = 1.f;
485  if (weights) {
486  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
487  }
488  __m256 vwgt = _mm256_set1_ps(wgt);
489  const float* ip = &input[idx * fused_block_size];
490  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
491  ? (dataInd + prefdist_T0)
492  : dataInd;
493  const int64_t idx_pref_T0 = indices[next_T0];
494  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
495  return false;
496  }
497  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
498  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
499  _mm_prefetch(
500  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
501  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
502  // skip unnecessary prefetch of (&ip_next_T0[8])
503  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
504  _mm_prefetch(
505  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
506  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
507  // skip unnecessary prefetch of (&ip_next_T0[24])
508  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
509  _mm_prefetch(
510  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
511  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
512  // skip unnecessary prefetch of (&ip_next_T0[40])
513  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
514  _mm_prefetch(
515  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
516  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
517  // skip unnecessary prefetch of (&ip_next_T0[56])
518  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
519  _mm_prefetch(
520  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
521  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
522  // skip unnecessary prefetch of (&ip_next_T0[72])
523  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
524  _mm_prefetch(
525  reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
526  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
527  // skip unnecessary prefetch of (&ip_next_T0[88])
528  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
529  _mm_prefetch(
530  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
531  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
532  // skip unnecessary prefetch of (&ip_next_T0[104])
533  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
534  _mm_prefetch(
535  reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
536  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
537  // skip unnecessary prefetch of (&ip_next_T0[120])
538  }
539  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
540  _mm256_storeu_ps(&op[0], vop0);
541  _mm256_storeu_ps(&op[8], vop8);
542  _mm256_storeu_ps(&op[16], vop16);
543  _mm256_storeu_ps(&op[24], vop24);
544  _mm256_storeu_ps(&op[32], vop32);
545  _mm256_storeu_ps(&op[40], vop40);
546  _mm256_storeu_ps(&op[48], vop48);
547  _mm256_storeu_ps(&op[56], vop56);
548  _mm256_storeu_ps(&op[64], vop64);
549  _mm256_storeu_ps(&op[72], vop72);
550  _mm256_storeu_ps(&op[80], vop80);
551  _mm256_storeu_ps(&op[88], vop88);
552  _mm256_storeu_ps(&op[96], vop96);
553  _mm256_storeu_ps(&op[104], vop104);
554  _mm256_storeu_ps(&op[112], vop112);
555  _mm256_storeu_ps(&op[120], vop120);
556  } else {
557  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
558  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
559  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
560  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
561  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
562  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
563  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
564  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
565  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
566  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
567  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
568  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
569  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
570  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
571  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
572  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
573  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
574  }
575  }
576  } else if (block_size == 64) {
577  // unrolling 8 times
578  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
579  float* op = &out[rangeIndex * block_size];
580  __m256 vop0 = _mm256_setzero_ps();
581  __m256 vop8 = _mm256_setzero_ps();
582  __m256 vop16 = _mm256_setzero_ps();
583  __m256 vop24 = _mm256_setzero_ps();
584  __m256 vop32 = _mm256_setzero_ps();
585  __m256 vop40 = _mm256_setzero_ps();
586  __m256 vop48 = _mm256_setzero_ps();
587  __m256 vop56 = _mm256_setzero_ps();
588  if (dataInd + lengths[rangeIndex] > index_size) {
589  return false;
590  }
591  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
592  ++dataInd) {
593  const int64_t idx = indices[dataInd];
594  if (idx < 0 || idx >= data_size) {
595  return false;
596  }
597  float wgt = 1.f;
598  if (weights) {
599  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
600  }
601  __m256 vwgt = _mm256_set1_ps(wgt);
602  const float* ip = &input[idx * fused_block_size];
603  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
604  ? (dataInd + prefdist_T0)
605  : dataInd;
606  const int64_t idx_pref_T0 = indices[next_T0];
607  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
608  return false;
609  }
610  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
611  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
612  _mm_prefetch(
613  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
614  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
615  // skip unnecessary prefetch of (&ip_next_T0[8])
616  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
617  _mm_prefetch(
618  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
619  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
620  // skip unnecessary prefetch of (&ip_next_T0[24])
621  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
622  _mm_prefetch(
623  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
624  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
625  // skip unnecessary prefetch of (&ip_next_T0[40])
626  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
627  _mm_prefetch(
628  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
629  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
630  // skip unnecessary prefetch of (&ip_next_T0[56])
631  }
632  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
633  _mm256_storeu_ps(&op[0], vop0);
634  _mm256_storeu_ps(&op[8], vop8);
635  _mm256_storeu_ps(&op[16], vop16);
636  _mm256_storeu_ps(&op[24], vop24);
637  _mm256_storeu_ps(&op[32], vop32);
638  _mm256_storeu_ps(&op[40], vop40);
639  _mm256_storeu_ps(&op[48], vop48);
640  _mm256_storeu_ps(&op[56], vop56);
641  } else {
642  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
643  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
644  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
645  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
646  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
647  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
648  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
649  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
650  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
651  }
652  }
653  } else if (block_size == 32) {
654  // unrolling 4 times
655  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
656  float* op = &out[rangeIndex * block_size];
657  __m256 vop0 = _mm256_setzero_ps();
658  __m256 vop8 = _mm256_setzero_ps();
659  __m256 vop16 = _mm256_setzero_ps();
660  __m256 vop24 = _mm256_setzero_ps();
661  if (dataInd + lengths[rangeIndex] > index_size) {
662  return false;
663  }
664  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
665  ++dataInd) {
666  const int64_t idx = indices[dataInd];
667  if (idx < 0 || idx >= data_size) {
668  return false;
669  }
670  float wgt = 1.f;
671  if (weights) {
672  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
673  }
674  __m256 vwgt = _mm256_set1_ps(wgt);
675  const float* ip = &input[idx * fused_block_size];
676  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
677  ? (dataInd + prefdist_T0)
678  : dataInd;
679  const int64_t idx_pref_T0 = indices[next_T0];
680  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
681  return false;
682  }
683  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
684  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
685  _mm_prefetch(
686  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
687  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
688  // skip unnecessary prefetch of (&ip_next_T0[8])
689  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
690  _mm_prefetch(
691  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
692  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
693  // skip unnecessary prefetch of (&ip_next_T0[24])
694  }
695  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
696  _mm256_storeu_ps(&op[0], vop0);
697  _mm256_storeu_ps(&op[8], vop8);
698  _mm256_storeu_ps(&op[16], vop16);
699  _mm256_storeu_ps(&op[24], vop24);
700  } else {
701  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
702  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
703  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
704  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
705  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
706  }
707  }
708  } else if (block_size == 16) {
709  // unrolling 2 times
710  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
711  float* op = &out[rangeIndex * block_size];
712  __m256 vop0 = _mm256_setzero_ps();
713  __m256 vop8 = _mm256_setzero_ps();
714  if (dataInd + lengths[rangeIndex] > index_size) {
715  return false;
716  }
717  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
718  ++dataInd) {
719  const int64_t idx = indices[dataInd];
720  if (idx < 0 || idx >= data_size) {
721  return false;
722  }
723  float wgt = 1.f;
724  if (weights) {
725  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
726  }
727  __m256 vwgt = _mm256_set1_ps(wgt);
728  const float* ip = &input[idx * fused_block_size];
729  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
730  ? (dataInd + prefdist_T0)
731  : dataInd;
732  const int64_t idx_pref_T0 = indices[next_T0];
733  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
734  return false;
735  }
736  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
737  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
738  _mm_prefetch(
739  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
740  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
741  // skip unnecessary prefetch of (&ip_next_T0[8])
742  }
743  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
744  _mm256_storeu_ps(&op[0], vop0);
745  _mm256_storeu_ps(&op[8], vop8);
746  } else {
747  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
748  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
749  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
750  }
751  }
752  } else {
753  // generic code
754  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
755  float* op = &out[rangeIndex * block_size];
756  int64_t j = 0;
757  for (; j + 8 <= block_size; j += 8) {
758  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
759  }
760  for (; j < block_size; j++) {
761  op[j] = 0.0f;
762  }
763  if (dataInd + lengths[rangeIndex] > index_size) {
764  return false;
765  }
766  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
767  ++dataInd) {
768  const int64_t idx = indices[dataInd];
769  if (idx < 0 || idx >= data_size) {
770  return false;
771  }
772  float wgt = 1.f;
773  if (weights) {
774  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
775  }
776  __m256 vwgt = _mm256_set1_ps(wgt);
777  const float* ip = &input[idx * fused_block_size];
778  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
779  ? (dataInd + prefdist_T0)
780  : dataInd;
781  const int64_t idx_pref_T0 = indices[next_T0];
782  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
783  return false;
784  }
785  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
786  j = 0;
787  for (; j + 8 <= block_size; j += 8) {
788  _mm256_storeu_ps(
789  &op[j],
790  _mm256_fmadd_ps(
791  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
792  _mm_prefetch(
793  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
794  }
795  for (; j < block_size; j++) {
796  op[j] += wgt * ip[j];
797  }
798  }
799  if (normalize_by_lengths && lengths[rangeIndex]) {
800  float len_inv = 1.0f / lengths[rangeIndex];
801  __m256 vlen_inv = _mm256_set1_ps(len_inv);
802  j = 0;
803  for (; j + 8 <= block_size; j += 8) {
804  _mm256_storeu_ps(
805  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
806  }
807  for (; j < block_size; j++) {
808  op[j] = len_inv * op[j];
809  }
810  }
811  }
812  }
813  return dataInd == index_size;
814 }
815 bool EmbeddingLookup_int64_t_float_float_false__avx2_fma(
816  const int64_t block_size,
817  const int64_t output_size,
818  const int64_t index_size,
819  const int64_t data_size,
820  const float* input,
821  const int64_t* indices,
822  const int* lengths,
823  const float* weights,
824  const float* scale_bias,
825  bool normalize_by_lengths,
826  float* out) {
827  return EmbeddingLookup_int64_t_float_float__avx2_fma<false>(
828  block_size,
829  output_size,
830  index_size,
831  data_size,
832  input,
833  indices,
834  lengths,
835  weights,
836  scale_bias,
837  normalize_by_lengths,
838  out);
839 }
840 bool EmbeddingLookup_int64_t_float_float_true__avx2_fma(
841  const int64_t block_size,
842  const int64_t output_size,
843  const int64_t index_size,
844  const int64_t data_size,
845  const float* input,
846  const int64_t* indices,
847  const int* lengths,
848  const float* weights,
849  const float* scale_bias,
850  bool normalize_by_lengths,
851  float* out) {
852  return EmbeddingLookup_int64_t_float_float__avx2_fma<true>(
853  block_size,
854  output_size,
855  index_size,
856  data_size,
857  input,
858  indices,
859  lengths,
860  weights,
861  scale_bias,
862  normalize_by_lengths,
863  out);
864 }
865 
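// The at::Half variants below follow the same structure, but each 8-element
// group is first widened from fp16 to fp32 with _mm256_cvtph_ps (F16C) before
// the FMA accumulation; the scalar tail of the generic path converts one
// element at a time through a small aligned buffer.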
866 template <bool IS_WEIGHT_POSITIONAL>
867 static bool EmbeddingLookup_int32_t_half_float__avx2_fma(
868  const int64_t block_size,
869  const int64_t output_size,
870  const int64_t index_size,
871  const int64_t data_size,
872  const at::Half* input,
873  const int* indices,
874  const int* lengths,
875  const float* weights,
876  const float* scale_bias,
877  bool normalize_by_lengths,
878  float* out) {
879  const int prefdist_T0 = 16;
880  const int fused_block_size = block_size + 0;
881  int dataInd = 0;
882  if (block_size == 128) {
883  // unrolling 16 times
884  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
885  float* op = &out[rangeIndex * block_size];
886  __m256 vop0 = _mm256_setzero_ps();
887  __m256 vop8 = _mm256_setzero_ps();
888  __m256 vop16 = _mm256_setzero_ps();
889  __m256 vop24 = _mm256_setzero_ps();
890  __m256 vop32 = _mm256_setzero_ps();
891  __m256 vop40 = _mm256_setzero_ps();
892  __m256 vop48 = _mm256_setzero_ps();
893  __m256 vop56 = _mm256_setzero_ps();
894  __m256 vop64 = _mm256_setzero_ps();
895  __m256 vop72 = _mm256_setzero_ps();
896  __m256 vop80 = _mm256_setzero_ps();
897  __m256 vop88 = _mm256_setzero_ps();
898  __m256 vop96 = _mm256_setzero_ps();
899  __m256 vop104 = _mm256_setzero_ps();
900  __m256 vop112 = _mm256_setzero_ps();
901  __m256 vop120 = _mm256_setzero_ps();
902  if (dataInd + lengths[rangeIndex] > index_size) {
903  return false;
904  }
905  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
906  ++dataInd) {
907  const int idx = indices[dataInd];
908  if (idx < 0 || idx >= data_size) {
909  return false;
910  }
911  float wgt = 1.f;
912  if (weights) {
913  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
914  }
915  __m256 vwgt = _mm256_set1_ps(wgt);
916  const at::Half* ip = &input[idx * fused_block_size];
917  const int next_T0 = (dataInd < index_size - prefdist_T0)
918  ? (dataInd + prefdist_T0)
919  : dataInd;
920  const int idx_pref_T0 = indices[next_T0];
921  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
922  return false;
923  }
924  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
925  vop0 = _mm256_fmadd_ps(
926  vwgt,
927  _mm256_cvtph_ps(
928  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
929  vop0);
930  _mm_prefetch(
931  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
932  vop8 = _mm256_fmadd_ps(
933  vwgt,
934  _mm256_cvtph_ps(
935  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
936  vop8);
937  // skip unnecessary prefetch of (&ip_next_T0[8])
938  vop16 = _mm256_fmadd_ps(
939  vwgt,
940  _mm256_cvtph_ps(
941  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
942  vop16);
943  // skip unnecessary prefetch of (&ip_next_T0[16])
944  vop24 = _mm256_fmadd_ps(
945  vwgt,
946  _mm256_cvtph_ps(
947  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
948  vop24);
949  // skip unnecessary prefetch of (&ip_next_T0[24])
950  vop32 = _mm256_fmadd_ps(
951  vwgt,
952  _mm256_cvtph_ps(
953  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
954  vop32);
955  _mm_prefetch(
956  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
957  vop40 = _mm256_fmadd_ps(
958  vwgt,
959  _mm256_cvtph_ps(
960  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
961  vop40);
962  // skip unnecessary prefetch of (&ip_next_T0[40])
963  vop48 = _mm256_fmadd_ps(
964  vwgt,
965  _mm256_cvtph_ps(
966  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
967  vop48);
968  // skip unnecessary prefetch of (&ip_next_T0[48])
969  vop56 = _mm256_fmadd_ps(
970  vwgt,
971  _mm256_cvtph_ps(
972  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
973  vop56);
974  // skip unnecessary prefetch of (&ip_next_T0[56])
975  vop64 = _mm256_fmadd_ps(
976  vwgt,
977  _mm256_cvtph_ps(
978  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
979  vop64);
980  _mm_prefetch(
981  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
982  vop72 = _mm256_fmadd_ps(
983  vwgt,
984  _mm256_cvtph_ps(
985  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
986  vop72);
987  // skip unnecessary prefetch of (&ip_next_T0[72])
988  vop80 = _mm256_fmadd_ps(
989  vwgt,
990  _mm256_cvtph_ps(
991  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
992  vop80);
993  // skip unnecessary prefetch of (&ip_next_T0[80])
994  vop88 = _mm256_fmadd_ps(
995  vwgt,
996  _mm256_cvtph_ps(
997  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
998  vop88);
999  // skip unnecessary prefetch of (&ip_next_T0[88])
1000  vop96 = _mm256_fmadd_ps(
1001  vwgt,
1002  _mm256_cvtph_ps(
1003  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1004  vop96);
1005  _mm_prefetch(
1006  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1007  vop104 = _mm256_fmadd_ps(
1008  vwgt,
1009  _mm256_cvtph_ps(
1010  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1011  vop104);
1012  // skip unnecessary prefetch of (&ip_next_T0[104])
1013  vop112 = _mm256_fmadd_ps(
1014  vwgt,
1015  _mm256_cvtph_ps(
1016  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1017  vop112);
1018  // skip unnecessary prefetch of (&ip_next_T0[112])
1019  vop120 = _mm256_fmadd_ps(
1020  vwgt,
1021  _mm256_cvtph_ps(
1022  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1023  vop120);
1024  // skip unnecessary prefetch of (&ip_next_T0[120])
1025  }
1026  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1027  _mm256_storeu_ps(&op[0], vop0);
1028  _mm256_storeu_ps(&op[8], vop8);
1029  _mm256_storeu_ps(&op[16], vop16);
1030  _mm256_storeu_ps(&op[24], vop24);
1031  _mm256_storeu_ps(&op[32], vop32);
1032  _mm256_storeu_ps(&op[40], vop40);
1033  _mm256_storeu_ps(&op[48], vop48);
1034  _mm256_storeu_ps(&op[56], vop56);
1035  _mm256_storeu_ps(&op[64], vop64);
1036  _mm256_storeu_ps(&op[72], vop72);
1037  _mm256_storeu_ps(&op[80], vop80);
1038  _mm256_storeu_ps(&op[88], vop88);
1039  _mm256_storeu_ps(&op[96], vop96);
1040  _mm256_storeu_ps(&op[104], vop104);
1041  _mm256_storeu_ps(&op[112], vop112);
1042  _mm256_storeu_ps(&op[120], vop120);
1043  } else {
1044  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1045  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1046  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1047  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1048  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1049  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1050  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1051  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1052  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1053  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1054  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1055  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1056  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1057  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1058  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1059  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1060  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1061  }
1062  }
1063  } else if (block_size == 64) {
1064  // unrolling 8 times
1065  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1066  float* op = &out[rangeIndex * block_size];
1067  __m256 vop0 = _mm256_setzero_ps();
1068  __m256 vop8 = _mm256_setzero_ps();
1069  __m256 vop16 = _mm256_setzero_ps();
1070  __m256 vop24 = _mm256_setzero_ps();
1071  __m256 vop32 = _mm256_setzero_ps();
1072  __m256 vop40 = _mm256_setzero_ps();
1073  __m256 vop48 = _mm256_setzero_ps();
1074  __m256 vop56 = _mm256_setzero_ps();
1075  if (dataInd + lengths[rangeIndex] > index_size) {
1076  return false;
1077  }
1078  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1079  ++dataInd) {
1080  const int idx = indices[dataInd];
1081  if (idx < 0 || idx >= data_size) {
1082  return false;
1083  }
1084  float wgt = 1.f;
1085  if (weights) {
1086  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1087  }
1088  __m256 vwgt = _mm256_set1_ps(wgt);
1089  const at::Half* ip = &input[idx * fused_block_size];
1090  const int next_T0 = (dataInd < index_size - prefdist_T0)
1091  ? (dataInd + prefdist_T0)
1092  : dataInd;
1093  const int idx_pref_T0 = indices[next_T0];
1094  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1095  return false;
1096  }
1097  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1098  vop0 = _mm256_fmadd_ps(
1099  vwgt,
1100  _mm256_cvtph_ps(
1101  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1102  vop0);
1103  _mm_prefetch(
1104  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1105  vop8 = _mm256_fmadd_ps(
1106  vwgt,
1107  _mm256_cvtph_ps(
1108  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1109  vop8);
1110  // skip unnecessary prefetch of (&ip_next_T0[8])
1111  vop16 = _mm256_fmadd_ps(
1112  vwgt,
1113  _mm256_cvtph_ps(
1114  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1115  vop16);
1116  // skip unnecessary prefetch of (&ip_next_T0[16])
1117  vop24 = _mm256_fmadd_ps(
1118  vwgt,
1119  _mm256_cvtph_ps(
1120  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1121  vop24);
1122  // skip unnecessary prefetch of (&ip_next_T0[24])
1123  vop32 = _mm256_fmadd_ps(
1124  vwgt,
1125  _mm256_cvtph_ps(
1126  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1127  vop32);
1128  _mm_prefetch(
1129  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1130  vop40 = _mm256_fmadd_ps(
1131  vwgt,
1132  _mm256_cvtph_ps(
1133  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1134  vop40);
1135  // skip unnecessary prefetch of (&ip_next_T0[40])
1136  vop48 = _mm256_fmadd_ps(
1137  vwgt,
1138  _mm256_cvtph_ps(
1139  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1140  vop48);
1141  // skip unnecessary prefetch of (&ip_next_T0[48])
1142  vop56 = _mm256_fmadd_ps(
1143  vwgt,
1144  _mm256_cvtph_ps(
1145  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1146  vop56);
1147  // skip unnecessary prefetch of (&ip_next_T0[56])
1148  }
1149  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1150  _mm256_storeu_ps(&op[0], vop0);
1151  _mm256_storeu_ps(&op[8], vop8);
1152  _mm256_storeu_ps(&op[16], vop16);
1153  _mm256_storeu_ps(&op[24], vop24);
1154  _mm256_storeu_ps(&op[32], vop32);
1155  _mm256_storeu_ps(&op[40], vop40);
1156  _mm256_storeu_ps(&op[48], vop48);
1157  _mm256_storeu_ps(&op[56], vop56);
1158  } else {
1159  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1160  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1161  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1162  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1163  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1164  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1165  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1166  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1167  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1168  }
1169  }
1170  } else if (block_size == 32) {
1171  // unrolling 4 times
1172  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1173  float* op = &out[rangeIndex * block_size];
1174  __m256 vop0 = _mm256_setzero_ps();
1175  __m256 vop8 = _mm256_setzero_ps();
1176  __m256 vop16 = _mm256_setzero_ps();
1177  __m256 vop24 = _mm256_setzero_ps();
1178  if (dataInd + lengths[rangeIndex] > index_size) {
1179  return false;
1180  }
1181  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1182  ++dataInd) {
1183  const int idx = indices[dataInd];
1184  if (idx < 0 || idx >= data_size) {
1185  return false;
1186  }
1187  float wgt = 1.f;
1188  if (weights) {
1189  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1190  }
1191  __m256 vwgt = _mm256_set1_ps(wgt);
1192  const at::Half* ip = &input[idx * fused_block_size];
1193  const int next_T0 = (dataInd < index_size - prefdist_T0)
1194  ? (dataInd + prefdist_T0)
1195  : dataInd;
1196  const int idx_pref_T0 = indices[next_T0];
1197  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1198  return false;
1199  }
1200  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1201  vop0 = _mm256_fmadd_ps(
1202  vwgt,
1203  _mm256_cvtph_ps(
1204  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1205  vop0);
1206  _mm_prefetch(
1207  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1208  vop8 = _mm256_fmadd_ps(
1209  vwgt,
1210  _mm256_cvtph_ps(
1211  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1212  vop8);
1213  // skip unnecessary prefetch of (&ip_next_T0[8])
1214  vop16 = _mm256_fmadd_ps(
1215  vwgt,
1216  _mm256_cvtph_ps(
1217  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1218  vop16);
1219  // skip unnecessary prefetch of (&ip_next_T0[16])
1220  vop24 = _mm256_fmadd_ps(
1221  vwgt,
1222  _mm256_cvtph_ps(
1223  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1224  vop24);
1225  // skip unnecessary prefetch of (&ip_next_T0[24])
1226  }
1227  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1228  _mm256_storeu_ps(&op[0], vop0);
1229  _mm256_storeu_ps(&op[8], vop8);
1230  _mm256_storeu_ps(&op[16], vop16);
1231  _mm256_storeu_ps(&op[24], vop24);
1232  } else {
1233  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1234  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1235  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1236  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1237  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1238  }
1239  }
1240  } else if (block_size == 16) {
1241  // unrolling 2 times
1242  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1243  float* op = &out[rangeIndex * block_size];
1244  __m256 vop0 = _mm256_setzero_ps();
1245  __m256 vop8 = _mm256_setzero_ps();
1246  if (dataInd + lengths[rangeIndex] > index_size) {
1247  return false;
1248  }
1249  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1250  ++dataInd) {
1251  const int idx = indices[dataInd];
1252  if (idx < 0 || idx >= data_size) {
1253  return false;
1254  }
1255  float wgt = 1.f;
1256  if (weights) {
1257  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1258  }
1259  __m256 vwgt = _mm256_set1_ps(wgt);
1260  const at::Half* ip = &input[idx * fused_block_size];
1261  const int next_T0 = (dataInd < index_size - prefdist_T0)
1262  ? (dataInd + prefdist_T0)
1263  : dataInd;
1264  const int idx_pref_T0 = indices[next_T0];
1265  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1266  return false;
1267  }
1268  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1269  vop0 = _mm256_fmadd_ps(
1270  vwgt,
1271  _mm256_cvtph_ps(
1272  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1273  vop0);
1274  _mm_prefetch(
1275  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1276  vop8 = _mm256_fmadd_ps(
1277  vwgt,
1278  _mm256_cvtph_ps(
1279  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1280  vop8);
1281  // skip unnecessary prefetch of (&ip_next_T0[8])
1282  }
1283  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1284  _mm256_storeu_ps(&op[0], vop0);
1285  _mm256_storeu_ps(&op[8], vop8);
1286  } else {
1287  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1288  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1289  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1290  }
1291  }
1292  } else {
1293  // generic code
1294  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1295  float* op = &out[rangeIndex * block_size];
1296  int64_t j = 0;
1297  for (; j + 8 <= block_size; j += 8) {
1298  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1299  }
1300  for (; j < block_size; j++) {
1301  op[j] = 0.0f;
1302  }
1303  if (dataInd + lengths[rangeIndex] > index_size) {
1304  return false;
1305  }
1306  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1307  ++dataInd) {
1308  const int idx = indices[dataInd];
1309  if (idx < 0 || idx >= data_size) {
1310  return false;
1311  }
1312  float wgt = 1.f;
1313  if (weights) {
1314  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1315  }
1316  __m256 vwgt = _mm256_set1_ps(wgt);
1317  const at::Half* ip = &input[idx * fused_block_size];
1318  const int next_T0 = (dataInd < index_size - prefdist_T0)
1319  ? (dataInd + prefdist_T0)
1320  : dataInd;
1321  const int idx_pref_T0 = indices[next_T0];
1322  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1323  return false;
1324  }
1325  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1326  j = 0;
1327  for (; j + 8 <= block_size; j += 8) {
1328  _mm256_storeu_ps(
1329  &op[j],
1330  _mm256_fmadd_ps(
1331  vwgt,
1332  _mm256_cvtph_ps(_mm_loadu_si128(
1333  reinterpret_cast<const __m128i*>(&ip[j]))),
1334  _mm256_loadu_ps(&op[j])));
1335  _mm_prefetch(
1336  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1337  }
1338  alignas(64) at::Half vtmp1[8];
1339  for (; j < block_size; j++) {
1340  vtmp1[0] = ip[j];
1341  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1342  op[j] += wgt * ((float*)(&vtmp2))[0];
1343  }
1344  }
1345  if (normalize_by_lengths && lengths[rangeIndex]) {
1346  float len_inv = 1.0f / lengths[rangeIndex];
1347  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1348  j = 0;
1349  for (; j + 8 <= block_size; j += 8) {
1350  _mm256_storeu_ps(
1351  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1352  }
1353  for (; j < block_size; j++) {
1354  op[j] = len_inv * op[j];
1355  }
1356  }
1357  }
1358  }
1359  return dataInd == index_size;
1360 }
1361 bool EmbeddingLookup_int32_t_half_float_false__avx2_fma(
1362  const int64_t block_size,
1363  const int64_t output_size,
1364  const int64_t index_size,
1365  const int64_t data_size,
1366  const at::Half* input,
1367  const int* indices,
1368  const int* lengths,
1369  const float* weights,
1370  const float* scale_bias,
1371  bool normalize_by_lengths,
1372  float* out) {
1373  return EmbeddingLookup_int32_t_half_float__avx2_fma<false>(
1374  block_size,
1375  output_size,
1376  index_size,
1377  data_size,
1378  input,
1379  indices,
1380  lengths,
1381  weights,
1382  scale_bias,
1383  normalize_by_lengths,
1384  out);
1385 }
1386 bool EmbeddingLookup_int32_t_half_float_true__avx2_fma(
1387  const int64_t block_size,
1388  const int64_t output_size,
1389  const int64_t index_size,
1390  const int64_t data_size,
1391  const at::Half* input,
1392  const int* indices,
1393  const int* lengths,
1394  const float* weights,
1395  const float* scale_bias,
1396  bool normalize_by_lengths,
1397  float* out) {
1398  return EmbeddingLookup_int32_t_half_float__avx2_fma<true>(
1399  block_size,
1400  output_size,
1401  index_size,
1402  data_size,
1403  input,
1404  indices,
1405  lengths,
1406  weights,
1407  scale_bias,
1408  normalize_by_lengths,
1409  out);
1410 }
1411 
1412 template <bool IS_WEIGHT_POSITIONAL>
1413 static bool EmbeddingLookup_int64_t_half_float__avx2_fma(
1414  const int64_t block_size,
1415  const int64_t output_size,
1416  const int64_t index_size,
1417  const int64_t data_size,
1418  const at::Half* input,
1419  const int64_t* indices,
1420  const int* lengths,
1421  const float* weights,
1422  const float* scale_bias,
1423  bool normalize_by_lengths,
1424  float* out) {
1425  const int64_t prefdist_T0 = 16;
1426  const int64_t fused_block_size = block_size + 0;
1427  int64_t dataInd = 0;
1428  if (block_size == 128) {
1429  // unrolling 16 times
1430  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1431  float* op = &out[rangeIndex * block_size];
1432  __m256 vop0 = _mm256_setzero_ps();
1433  __m256 vop8 = _mm256_setzero_ps();
1434  __m256 vop16 = _mm256_setzero_ps();
1435  __m256 vop24 = _mm256_setzero_ps();
1436  __m256 vop32 = _mm256_setzero_ps();
1437  __m256 vop40 = _mm256_setzero_ps();
1438  __m256 vop48 = _mm256_setzero_ps();
1439  __m256 vop56 = _mm256_setzero_ps();
1440  __m256 vop64 = _mm256_setzero_ps();
1441  __m256 vop72 = _mm256_setzero_ps();
1442  __m256 vop80 = _mm256_setzero_ps();
1443  __m256 vop88 = _mm256_setzero_ps();
1444  __m256 vop96 = _mm256_setzero_ps();
1445  __m256 vop104 = _mm256_setzero_ps();
1446  __m256 vop112 = _mm256_setzero_ps();
1447  __m256 vop120 = _mm256_setzero_ps();
1448  if (dataInd + lengths[rangeIndex] > index_size) {
1449  return false;
1450  }
1451  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1452  ++dataInd) {
1453  const int64_t idx = indices[dataInd];
1454  if (idx < 0 || idx >= data_size) {
1455  return false;
1456  }
1457  float wgt = 1.f;
1458  if (weights) {
1459  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1460  }
1461  __m256 vwgt = _mm256_set1_ps(wgt);
1462  const at::Half* ip = &input[idx * fused_block_size];
1463  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1464  ? (dataInd + prefdist_T0)
1465  : dataInd;
1466  const int64_t idx_pref_T0 = indices[next_T0];
1467  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1468  return false;
1469  }
1470  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1471  vop0 = _mm256_fmadd_ps(
1472  vwgt,
1473  _mm256_cvtph_ps(
1474  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1475  vop0);
1476  _mm_prefetch(
1477  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1478  vop8 = _mm256_fmadd_ps(
1479  vwgt,
1480  _mm256_cvtph_ps(
1481  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1482  vop8);
1483  // skip unnecessary prefetch of (&ip_next_T0[8])
1484  vop16 = _mm256_fmadd_ps(
1485  vwgt,
1486  _mm256_cvtph_ps(
1487  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1488  vop16);
1489  // skip unnecessary prefetch of (&ip_next_T0[16])
1490  vop24 = _mm256_fmadd_ps(
1491  vwgt,
1492  _mm256_cvtph_ps(
1493  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1494  vop24);
1495  // skip unnecessary prefetch of (&ip_next_T0[24])
1496  vop32 = _mm256_fmadd_ps(
1497  vwgt,
1498  _mm256_cvtph_ps(
1499  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1500  vop32);
1501  _mm_prefetch(
1502  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1503  vop40 = _mm256_fmadd_ps(
1504  vwgt,
1505  _mm256_cvtph_ps(
1506  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1507  vop40);
1508  // skip unnecessary prefetch of (&ip_next_T0[40])
1509  vop48 = _mm256_fmadd_ps(
1510  vwgt,
1511  _mm256_cvtph_ps(
1512  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1513  vop48);
1514  // skip unnecessary prefetch of (&ip_next_T0[48])
1515  vop56 = _mm256_fmadd_ps(
1516  vwgt,
1517  _mm256_cvtph_ps(
1518  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1519  vop56);
1520  // skip unnecessary prefetch of (&ip_next_T0[56])
1521  vop64 = _mm256_fmadd_ps(
1522  vwgt,
1523  _mm256_cvtph_ps(
1524  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1525  vop64);
1526  _mm_prefetch(
1527  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1528  vop72 = _mm256_fmadd_ps(
1529  vwgt,
1530  _mm256_cvtph_ps(
1531  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1532  vop72);
1533  // skip unnecessary prefetch of (&ip_next_T0[72])
1534  vop80 = _mm256_fmadd_ps(
1535  vwgt,
1536  _mm256_cvtph_ps(
1537  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1538  vop80);
1539  // skip unnecessary prefetch of (&ip_next_T0[80])
1540  vop88 = _mm256_fmadd_ps(
1541  vwgt,
1542  _mm256_cvtph_ps(
1543  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1544  vop88);
1545  // skip unnecessary prefetch of (&ip_next_T0[88])
1546  vop96 = _mm256_fmadd_ps(
1547  vwgt,
1548  _mm256_cvtph_ps(
1549  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1550  vop96);
1551  _mm_prefetch(
1552  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1553  vop104 = _mm256_fmadd_ps(
1554  vwgt,
1555  _mm256_cvtph_ps(
1556  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1557  vop104);
1558  // skip unnecessary prefetch of (&ip_next_T0[104])
1559  vop112 = _mm256_fmadd_ps(
1560  vwgt,
1561  _mm256_cvtph_ps(
1562  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1563  vop112);
1564  // skip unnecessary prefetch of (&ip_next_T0[112])
1565  vop120 = _mm256_fmadd_ps(
1566  vwgt,
1567  _mm256_cvtph_ps(
1568  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1569  vop120);
1570  // skip unnecessary prefetch of (&ip_next_T0[120])
1571  }
1572  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1573  _mm256_storeu_ps(&op[0], vop0);
1574  _mm256_storeu_ps(&op[8], vop8);
1575  _mm256_storeu_ps(&op[16], vop16);
1576  _mm256_storeu_ps(&op[24], vop24);
1577  _mm256_storeu_ps(&op[32], vop32);
1578  _mm256_storeu_ps(&op[40], vop40);
1579  _mm256_storeu_ps(&op[48], vop48);
1580  _mm256_storeu_ps(&op[56], vop56);
1581  _mm256_storeu_ps(&op[64], vop64);
1582  _mm256_storeu_ps(&op[72], vop72);
1583  _mm256_storeu_ps(&op[80], vop80);
1584  _mm256_storeu_ps(&op[88], vop88);
1585  _mm256_storeu_ps(&op[96], vop96);
1586  _mm256_storeu_ps(&op[104], vop104);
1587  _mm256_storeu_ps(&op[112], vop112);
1588  _mm256_storeu_ps(&op[120], vop120);
1589  } else {
1590  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1591  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1592  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1593  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1594  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1595  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1596  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1597  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1598  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1599  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1600  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1601  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1602  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1603  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1604  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1605  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1606  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1607  }
1608  }
1609  } else if (block_size == 64) {
1610  // unrolling 8 times
1611  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1612  float* op = &out[rangeIndex * block_size];
1613  __m256 vop0 = _mm256_setzero_ps();
1614  __m256 vop8 = _mm256_setzero_ps();
1615  __m256 vop16 = _mm256_setzero_ps();
1616  __m256 vop24 = _mm256_setzero_ps();
1617  __m256 vop32 = _mm256_setzero_ps();
1618  __m256 vop40 = _mm256_setzero_ps();
1619  __m256 vop48 = _mm256_setzero_ps();
1620  __m256 vop56 = _mm256_setzero_ps();
1621  if (dataInd + lengths[rangeIndex] > index_size) {
1622  return false;
1623  }
1624  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1625  ++dataInd) {
1626  const int64_t idx = indices[dataInd];
1627  if (idx < 0 || idx >= data_size) {
1628  return false;
1629  }
1630  float wgt = 1.f;
1631  if (weights) {
1632  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1633  }
1634  __m256 vwgt = _mm256_set1_ps(wgt);
1635  const at::Half* ip = &input[idx * fused_block_size];
1636  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1637  ? (dataInd + prefdist_T0)
1638  : dataInd;
1639  const int64_t idx_pref_T0 = indices[next_T0];
1640  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1641  return false;
1642  }
1643  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1644  vop0 = _mm256_fmadd_ps(
1645  vwgt,
1646  _mm256_cvtph_ps(
1647  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1648  vop0);
1649  _mm_prefetch(
1650  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1651  vop8 = _mm256_fmadd_ps(
1652  vwgt,
1653  _mm256_cvtph_ps(
1654  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1655  vop8);
1656  // skip unnecessary prefetch of (&ip_next_T0[8])
1657  vop16 = _mm256_fmadd_ps(
1658  vwgt,
1659  _mm256_cvtph_ps(
1660  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1661  vop16);
1662  // skip unnecessary prefetch of (&ip_next_T0[16])
1663  vop24 = _mm256_fmadd_ps(
1664  vwgt,
1665  _mm256_cvtph_ps(
1666  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1667  vop24);
1668  // skip unnecessary prefetch of (&ip_next_T0[24])
1669  vop32 = _mm256_fmadd_ps(
1670  vwgt,
1671  _mm256_cvtph_ps(
1672  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1673  vop32);
1674  _mm_prefetch(
1675  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1676  vop40 = _mm256_fmadd_ps(
1677  vwgt,
1678  _mm256_cvtph_ps(
1679  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1680  vop40);
1681  // skip unnecessary prefetch of (&ip_next_T0[40])
1682  vop48 = _mm256_fmadd_ps(
1683  vwgt,
1684  _mm256_cvtph_ps(
1685  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1686  vop48);
1687  // skip unnecessary prefetch of (&ip_next_T0[48])
1688  vop56 = _mm256_fmadd_ps(
1689  vwgt,
1690  _mm256_cvtph_ps(
1691  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1692  vop56);
1693  // skip unnecessary prefetch of (&ip_next_T0[56])
1694  }
1695  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1696  _mm256_storeu_ps(&op[0], vop0);
1697  _mm256_storeu_ps(&op[8], vop8);
1698  _mm256_storeu_ps(&op[16], vop16);
1699  _mm256_storeu_ps(&op[24], vop24);
1700  _mm256_storeu_ps(&op[32], vop32);
1701  _mm256_storeu_ps(&op[40], vop40);
1702  _mm256_storeu_ps(&op[48], vop48);
1703  _mm256_storeu_ps(&op[56], vop56);
1704  } else {
1705  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1706  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1707  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1708  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1709  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1710  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1711  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1712  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1713  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1714  }
1715  }
1716  } else if (block_size == 32) {
1717  // unrolling 4 times
1718  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1719  float* op = &out[rangeIndex * block_size];
1720  __m256 vop0 = _mm256_setzero_ps();
1721  __m256 vop8 = _mm256_setzero_ps();
1722  __m256 vop16 = _mm256_setzero_ps();
1723  __m256 vop24 = _mm256_setzero_ps();
1724  if (dataInd + lengths[rangeIndex] > index_size) {
1725  return false;
1726  }
1727  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1728  ++dataInd) {
1729  const int64_t idx = indices[dataInd];
1730  if (idx < 0 || idx >= data_size) {
1731  return false;
1732  }
1733  float wgt = 1.f;
1734  if (weights) {
1735  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1736  }
1737  __m256 vwgt = _mm256_set1_ps(wgt);
1738  const at::Half* ip = &input[idx * fused_block_size];
1739  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1740  ? (dataInd + prefdist_T0)
1741  : dataInd;
1742  const int64_t idx_pref_T0 = indices[next_T0];
1743  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1744  return false;
1745  }
1746  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1747  vop0 = _mm256_fmadd_ps(
1748  vwgt,
1749  _mm256_cvtph_ps(
1750  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1751  vop0);
1752  _mm_prefetch(
1753  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1754  vop8 = _mm256_fmadd_ps(
1755  vwgt,
1756  _mm256_cvtph_ps(
1757  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1758  vop8);
1759  // skip unnecessary prefetch of (&ip_next_T0[8])
1760  vop16 = _mm256_fmadd_ps(
1761  vwgt,
1762  _mm256_cvtph_ps(
1763  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1764  vop16);
1765  // skip unnecessary prefetch of (&ip_next_T0[16])
1766  vop24 = _mm256_fmadd_ps(
1767  vwgt,
1768  _mm256_cvtph_ps(
1769  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1770  vop24);
1771  // skip unnecessary prefetch of (&ip_next_T0[24])
1772  }
1773  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1774  _mm256_storeu_ps(&op[0], vop0);
1775  _mm256_storeu_ps(&op[8], vop8);
1776  _mm256_storeu_ps(&op[16], vop16);
1777  _mm256_storeu_ps(&op[24], vop24);
1778  } else {
1779  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1780  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1781  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1782  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1783  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1784  }
1785  }
1786  } else if (block_size == 16) {
1787  // unrolling 2 times
1788  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1789  float* op = &out[rangeIndex * block_size];
1790  __m256 vop0 = _mm256_setzero_ps();
1791  __m256 vop8 = _mm256_setzero_ps();
1792  if (dataInd + lengths[rangeIndex] > index_size) {
1793  return false;
1794  }
1795  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1796  ++dataInd) {
1797  const int64_t idx = indices[dataInd];
1798  if (idx < 0 || idx >= data_size) {
1799  return false;
1800  }
1801  float wgt = 1.f;
1802  if (weights) {
1803  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1804  }
1805  __m256 vwgt = _mm256_set1_ps(wgt);
1806  const at::Half* ip = &input[idx * fused_block_size];
1807  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1808  ? (dataInd + prefdist_T0)
1809  : dataInd;
1810  const int64_t idx_pref_T0 = indices[next_T0];
1811  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1812  return false;
1813  }
1814  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1815  vop0 = _mm256_fmadd_ps(
1816  vwgt,
1817  _mm256_cvtph_ps(
1818  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1819  vop0);
1820  _mm_prefetch(
1821  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1822  vop8 = _mm256_fmadd_ps(
1823  vwgt,
1824  _mm256_cvtph_ps(
1825  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1826  vop8);
1827  // skip unnecessary prefetch of (&ip_next_T0[8])
1828  }
1829  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1830  _mm256_storeu_ps(&op[0], vop0);
1831  _mm256_storeu_ps(&op[8], vop8);
1832  } else {
1833  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1834  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1835  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1836  }
1837  }
1838  } else {
1839  // generic code
1840  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1841  float* op = &out[rangeIndex * block_size];
1842  int64_t j = 0;
1843  for (; j + 8 <= block_size; j += 8) {
1844  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1845  }
1846  for (; j < block_size; j++) {
1847  op[j] = 0.0f;
1848  }
1849  if (dataInd + lengths[rangeIndex] > index_size) {
1850  return false;
1851  }
1852  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1853  ++dataInd) {
1854  const int64_t idx = indices[dataInd];
1855  if (idx < 0 || idx >= data_size) {
1856  return false;
1857  }
1858  float wgt = 1.f;
1859  if (weights) {
1860  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1861  }
1862  __m256 vwgt = _mm256_set1_ps(wgt);
1863  const at::Half* ip = &input[idx * fused_block_size];
1864  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1865  ? (dataInd + prefdist_T0)
1866  : dataInd;
1867  const int64_t idx_pref_T0 = indices[next_T0];
1868  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1869  return false;
1870  }
1871  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1872  j = 0;
1873  for (; j + 8 <= block_size; j += 8) {
1874  _mm256_storeu_ps(
1875  &op[j],
1876  _mm256_fmadd_ps(
1877  vwgt,
1878  _mm256_cvtph_ps(_mm_loadu_si128(
1879  reinterpret_cast<const __m128i*>(&ip[j]))),
1880  _mm256_loadu_ps(&op[j])));
1881  _mm_prefetch(
1882  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1883  }
1884  alignas(64) at::Half vtmp1[8];
1885  for (; j < block_size; j++) {
1886  vtmp1[0] = ip[j];
1887  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1888  op[j] += wgt * ((float*)(&vtmp2))[0];
1889  }
1890  }
1891  if (normalize_by_lengths && lengths[rangeIndex]) {
1892  float len_inv = 1.0f / lengths[rangeIndex];
1893  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1894  j = 0;
1895  for (; j + 8 <= block_size; j += 8) {
1896  _mm256_storeu_ps(
1897  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1898  }
1899  for (; j < block_size; j++) {
1900  op[j] = len_inv * op[j];
1901  }
1902  }
1903  }
1904  }
1905  return dataInd == index_size;
1906 }
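// For reference, a minimal scalar sketch of the reduction performed by the
// half-precision kernels above, under the same argument contract (scale_bias
// is unused for fp16 input). The helper name is illustrative only and is not
// part of the exported API.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupHalfScalarSketch(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights,
    bool normalize_by_lengths,
    float* out) {
  int64_t dataInd = 0;
  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
    float* op = &out[rangeIndex * block_size];
    for (int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.f;
    }
    if (dataInd + lengths[rangeIndex] > index_size) {
      return false; // lengths would overrun the index buffer
    }
    for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
         ++dataInd) {
      const int64_t idx = indices[dataInd];
      if (idx < 0 || idx >= data_size) {
        return false; // out-of-range embedding row
      }
      const float wgt = weights
          ? weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]
          : 1.f;
      const at::Half* ip = &input[idx * block_size];
      for (int64_t j = 0; j < block_size; ++j) {
        // fp16 -> fp32 conversion followed by multiply-accumulate; the AVX2
        // paths above do the same 8 lanes at a time via _mm256_cvtph_ps + FMA.
        op[j] += wgt * static_cast<float>(ip[j]);
      }
    }
    if (normalize_by_lengths && lengths[rangeIndex] != 0) {
      const float len_inv = 1.f / lengths[rangeIndex];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
  return dataInd == index_size;
}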
1907 bool EmbeddingLookup_int64_t_half_float_false__avx2_fma(
1908  const int64_t block_size,
1909  const int64_t output_size,
1910  const int64_t index_size,
1911  const int64_t data_size,
1912  const at::Half* input,
1913  const int64_t* indices,
1914  const int* lengths,
1915  const float* weights,
1916  const float* scale_bias,
1917  bool normalize_by_lengths,
1918  float* out) {
1919  return EmbeddingLookup_int64_t_half_float__avx2_fma<false>(
1920  block_size,
1921  output_size,
1922  index_size,
1923  data_size,
1924  input,
1925  indices,
1926  lengths,
1927  weights,
1928  scale_bias,
1929  normalize_by_lengths,
1930  out);
1931 }
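// In the generic path of the template kernel above (block sizes without a
// dedicated unroll), the final j-loop converts leftover fp16 elements one at a
// time by staging each in an aligned 8-element buffer, running a full 8-lane
// _mm256_cvtph_ps, and keeping only lane 0. A self-contained sketch of that
// single-element conversion (illustrative helper, not part of this file's API):
static inline float HalfToFloatViaCvtphSketch(at::Half h) {
  alignas(64) at::Half tmp[8] = {h}; // only lane 0 carries a meaningful value
  __m256 v = _mm256_cvtph_ps(*reinterpret_cast<const __m128i*>(tmp));
  return _mm256_cvtss_f32(v); // extract the converted lane 0
}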
1932 bool EmbeddingLookup_int64_t_half_float_true__avx2_fma(
1933  const int64_t block_size,
1934  const int64_t output_size,
1935  const int64_t index_size,
1936  const int64_t data_size,
1937  const at::Half* input,
1938  const int64_t* indices,
1939  const int* lengths,
1940  const float* weights,
1941  const float* scale_bias,
1942  bool normalize_by_lengths,
1943  float* out) {
1944  return EmbeddingLookup_int64_t_half_float__avx2_fma<true>(
1945  block_size,
1946  output_size,
1947  index_size,
1948  data_size,
1949  input,
1950  indices,
1951  lengths,
1952  weights,
1953  scale_bias,
1954  normalize_by_lengths,
1955  out);
1956 }
1957 
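// Illustrative only: a toy invocation of the exported int64/half entry point
// defined above. All sizes and values are made up for the example; in Caffe2
// these buffers come from the SparseLengths-style operators. The sketch also
// assumes <vector> is available.
static void ExampleHalfLookupCallSketch() {
  const int64_t block_size = 16;  // embedding dimension
  const int64_t output_size = 2;  // number of output rows (segments)
  const int64_t data_size = 4;    // number of embedding rows
  std::vector<at::Half> input(data_size * block_size, at::Half(1.0f));
  std::vector<int64_t> indices = {0, 2, 1, 3};
  std::vector<int> lengths = {2, 2}; // one length per output row
  std::vector<float> out(output_size * block_size, 0.f);

  const bool ok = EmbeddingLookup_int64_t_half_float_false__avx2_fma(
      block_size,
      output_size,
      /*index_size=*/static_cast<int64_t>(indices.size()),
      data_size,
      input.data(),
      indices.data(),
      lengths.data(),
      /*weights=*/nullptr,
      /*scale_bias=*/nullptr, // unused for fp16 input
      /*normalize_by_lengths=*/false,
      out.data());
  (void)ok; // false signals an out-of-range index or a length/index mismatch
}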
1958 template <bool IS_WEIGHT_POSITIONAL>
1959 static bool EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
1960  const int64_t block_size,
1961  const int64_t output_size,
1962  const int64_t index_size,
1963  const int64_t data_size,
1964  const uint8_t* input,
1965  const int* indices,
1966  const int* lengths,
1967  const float* weights,
1968  const float* scale_bias,
1969  bool normalize_by_lengths,
1970  float* out) {
1971  const int prefdist_T0 = 16;
1972  const int fused_block_size = block_size + 0;
1973  int dataInd = 0;
1974  if (block_size == 128) {
1975  // unrolling 16 times
1976  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1977  float* op = &out[rangeIndex * block_size];
1978  __m256 vop0 = _mm256_setzero_ps();
1979  __m256 vop8 = _mm256_setzero_ps();
1980  __m256 vop16 = _mm256_setzero_ps();
1981  __m256 vop24 = _mm256_setzero_ps();
1982  __m256 vop32 = _mm256_setzero_ps();
1983  __m256 vop40 = _mm256_setzero_ps();
1984  __m256 vop48 = _mm256_setzero_ps();
1985  __m256 vop56 = _mm256_setzero_ps();
1986  __m256 vop64 = _mm256_setzero_ps();
1987  __m256 vop72 = _mm256_setzero_ps();
1988  __m256 vop80 = _mm256_setzero_ps();
1989  __m256 vop88 = _mm256_setzero_ps();
1990  __m256 vop96 = _mm256_setzero_ps();
1991  __m256 vop104 = _mm256_setzero_ps();
1992  __m256 vop112 = _mm256_setzero_ps();
1993  __m256 vop120 = _mm256_setzero_ps();
1994  if (dataInd + lengths[rangeIndex] > index_size) {
1995  return false;
1996  }
1997  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1998  ++dataInd) {
1999  const int idx = indices[dataInd];
2000  if (idx < 0 || idx >= data_size) {
2001  return false;
2002  }
2003  float wgt = 1.f;
2004  float bio;
2005  if (weights) {
2006  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2007  }
2008  bio = wgt * scale_bias[2 * idx + 1];
2009  wgt = wgt * scale_bias[2 * idx];
2010  __m256 vbio = _mm256_set1_ps(bio);
2011  __m256 vwgt = _mm256_set1_ps(wgt);
2012  const uint8_t* ip = &input[idx * fused_block_size];
2013  const int next_T0 = (dataInd < index_size - prefdist_T0)
2014  ? (dataInd + prefdist_T0)
2015  : dataInd;
2016  const int idx_pref_T0 = indices[next_T0];
2017  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2018  return false;
2019  }
2020  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2021  vop0 = _mm256_fmadd_ps(
2022  vwgt,
2023  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2024  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2025  _mm256_add_ps(vop0, vbio));
2026  _mm_prefetch(
2027  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2028  vop8 = _mm256_fmadd_ps(
2029  vwgt,
2030  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2031  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2032  _mm256_add_ps(vop8, vbio));
2033  // skip unnecessary prefetch of (&ip_next_T0[8])
2034  vop16 = _mm256_fmadd_ps(
2035  vwgt,
2036  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2037  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2038  _mm256_add_ps(vop16, vbio));
2039  // skip unnecessary prefetch of (&ip_next_T0[16])
2040  vop24 = _mm256_fmadd_ps(
2041  vwgt,
2042  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2043  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2044  _mm256_add_ps(vop24, vbio));
2045  // skip unnecessary prefetch of (&ip_next_T0[24])
2046  vop32 = _mm256_fmadd_ps(
2047  vwgt,
2048  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2049  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2050  _mm256_add_ps(vop32, vbio));
2051  // skip unnecessary prefetch of (&ip_next_T0[32])
2052  vop40 = _mm256_fmadd_ps(
2053  vwgt,
2054  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2055  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2056  _mm256_add_ps(vop40, vbio));
2057  // skip unnecessary prefetch of (&ip_next_T0[40])
2058  vop48 = _mm256_fmadd_ps(
2059  vwgt,
2060  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2061  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2062  _mm256_add_ps(vop48, vbio));
2063  // skip unnecessary prefetch of (&ip_next_T0[48])
2064  vop56 = _mm256_fmadd_ps(
2065  vwgt,
2066  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2067  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2068  _mm256_add_ps(vop56, vbio));
2069  // skip unnecessary prefetch of (&ip_next_T0[56])
2070  vop64 = _mm256_fmadd_ps(
2071  vwgt,
2072  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2073  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2074  _mm256_add_ps(vop64, vbio));
2075  _mm_prefetch(
2076  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2077  vop72 = _mm256_fmadd_ps(
2078  vwgt,
2079  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2080  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2081  _mm256_add_ps(vop72, vbio));
2082  // skip unnecessary prefetch of (&ip_next_T0[72])
2083  vop80 = _mm256_fmadd_ps(
2084  vwgt,
2085  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2086  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2087  _mm256_add_ps(vop80, vbio));
2088  // skip unnecessary prefetch of (&ip_next_T0[80])
2089  vop88 = _mm256_fmadd_ps(
2090  vwgt,
2091  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2092  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2093  _mm256_add_ps(vop88, vbio));
2094  // skip unnecessary prefetch of (&ip_next_T0[88])
2095  vop96 = _mm256_fmadd_ps(
2096  vwgt,
2097  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2098  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2099  _mm256_add_ps(vop96, vbio));
2100  // skip unnecessary prefetch of (&ip_next_T0[96])
2101  vop104 = _mm256_fmadd_ps(
2102  vwgt,
2103  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2104  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2105  _mm256_add_ps(vop104, vbio));
2106  // skip unnecessary prefetch of (&ip_next_T0[104])
2107  vop112 = _mm256_fmadd_ps(
2108  vwgt,
2109  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2110  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2111  _mm256_add_ps(vop112, vbio));
2112  // skip unnecessary prefetch of (&ip_next_T0[112])
2113  vop120 = _mm256_fmadd_ps(
2114  vwgt,
2115  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2116  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2117  _mm256_add_ps(vop120, vbio));
2118  // skip unnecessary prefetch of (&ip_next_T0[120])
2119  }
2120  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2121  _mm256_storeu_ps(&op[0], vop0);
2122  _mm256_storeu_ps(&op[8], vop8);
2123  _mm256_storeu_ps(&op[16], vop16);
2124  _mm256_storeu_ps(&op[24], vop24);
2125  _mm256_storeu_ps(&op[32], vop32);
2126  _mm256_storeu_ps(&op[40], vop40);
2127  _mm256_storeu_ps(&op[48], vop48);
2128  _mm256_storeu_ps(&op[56], vop56);
2129  _mm256_storeu_ps(&op[64], vop64);
2130  _mm256_storeu_ps(&op[72], vop72);
2131  _mm256_storeu_ps(&op[80], vop80);
2132  _mm256_storeu_ps(&op[88], vop88);
2133  _mm256_storeu_ps(&op[96], vop96);
2134  _mm256_storeu_ps(&op[104], vop104);
2135  _mm256_storeu_ps(&op[112], vop112);
2136  _mm256_storeu_ps(&op[120], vop120);
2137  } else {
2138  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2139  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2140  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2141  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2142  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2143  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2144  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2145  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2146  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2147  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2148  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2149  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2150  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2151  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2152  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2153  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2154  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2155  }
2156  }
2157  } else if (block_size == 64) {
2158  // unrolling 8 times
2159  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2160  float* op = &out[rangeIndex * block_size];
2161  __m256 vop0 = _mm256_setzero_ps();
2162  __m256 vop8 = _mm256_setzero_ps();
2163  __m256 vop16 = _mm256_setzero_ps();
2164  __m256 vop24 = _mm256_setzero_ps();
2165  __m256 vop32 = _mm256_setzero_ps();
2166  __m256 vop40 = _mm256_setzero_ps();
2167  __m256 vop48 = _mm256_setzero_ps();
2168  __m256 vop56 = _mm256_setzero_ps();
2169  if (dataInd + lengths[rangeIndex] > index_size) {
2170  return false;
2171  }
2172  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2173  ++dataInd) {
2174  const int idx = indices[dataInd];
2175  if (idx < 0 || idx >= data_size) {
2176  return false;
2177  }
2178  float wgt = 1.f;
2179  float bio;
2180  if (weights) {
2181  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2182  }
2183  bio = wgt * scale_bias[2 * idx + 1];
2184  wgt = wgt * scale_bias[2 * idx];
2185  __m256 vbio = _mm256_set1_ps(bio);
2186  __m256 vwgt = _mm256_set1_ps(wgt);
2187  const uint8_t* ip = &input[idx * fused_block_size];
2188  const int next_T0 = (dataInd < index_size - prefdist_T0)
2189  ? (dataInd + prefdist_T0)
2190  : dataInd;
2191  const int idx_pref_T0 = indices[next_T0];
2192  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2193  return false;
2194  }
2195  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2196  vop0 = _mm256_fmadd_ps(
2197  vwgt,
2198  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2199  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2200  _mm256_add_ps(vop0, vbio));
2201  _mm_prefetch(
2202  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2203  vop8 = _mm256_fmadd_ps(
2204  vwgt,
2205  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2206  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2207  _mm256_add_ps(vop8, vbio));
2208  // skip unnecessary prefetch of (&ip_next_T0[8])
2209  vop16 = _mm256_fmadd_ps(
2210  vwgt,
2211  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2212  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2213  _mm256_add_ps(vop16, vbio));
2214  // skip unnecessary prefetch of (&ip_next_T0[16])
2215  vop24 = _mm256_fmadd_ps(
2216  vwgt,
2217  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2218  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2219  _mm256_add_ps(vop24, vbio));
2220  // skip unnecessary prefetch of (&ip_next_T0[24])
2221  vop32 = _mm256_fmadd_ps(
2222  vwgt,
2223  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2224  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2225  _mm256_add_ps(vop32, vbio));
2226  // skip unnecessary prefetch of (&ip_next_T0[32])
2227  vop40 = _mm256_fmadd_ps(
2228  vwgt,
2229  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2230  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2231  _mm256_add_ps(vop40, vbio));
2232  // skip unnecessary prefetch of (&ip_next_T0[40])
2233  vop48 = _mm256_fmadd_ps(
2234  vwgt,
2235  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2236  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2237  _mm256_add_ps(vop48, vbio));
2238  // skip unnecessary prefetch of (&ip_next_T0[48])
2239  vop56 = _mm256_fmadd_ps(
2240  vwgt,
2241  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2242  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2243  _mm256_add_ps(vop56, vbio));
2244  // skip unnecessary prefetch of (&ip_next_T0[56])
2245  }
2246  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2247  _mm256_storeu_ps(&op[0], vop0);
2248  _mm256_storeu_ps(&op[8], vop8);
2249  _mm256_storeu_ps(&op[16], vop16);
2250  _mm256_storeu_ps(&op[24], vop24);
2251  _mm256_storeu_ps(&op[32], vop32);
2252  _mm256_storeu_ps(&op[40], vop40);
2253  _mm256_storeu_ps(&op[48], vop48);
2254  _mm256_storeu_ps(&op[56], vop56);
2255  } else {
2256  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2257  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2258  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2259  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2260  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2261  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2262  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2263  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2264  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2265  }
2266  }
2267  } else if (block_size == 32) {
2268  // unrolling 4 times
2269  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2270  float* op = &out[rangeIndex * block_size];
2271  __m256 vop0 = _mm256_setzero_ps();
2272  __m256 vop8 = _mm256_setzero_ps();
2273  __m256 vop16 = _mm256_setzero_ps();
2274  __m256 vop24 = _mm256_setzero_ps();
2275  if (dataInd + lengths[rangeIndex] > index_size) {
2276  return false;
2277  }
2278  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2279  ++dataInd) {
2280  const int idx = indices[dataInd];
2281  if (idx < 0 || idx >= data_size) {
2282  return false;
2283  }
2284  float wgt = 1.f;
2285  float bio;
2286  if (weights) {
2287  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2288  }
2289  bio = wgt * scale_bias[2 * idx + 1];
2290  wgt = wgt * scale_bias[2 * idx];
2291  __m256 vbio = _mm256_set1_ps(bio);
2292  __m256 vwgt = _mm256_set1_ps(wgt);
2293  const uint8_t* ip = &input[idx * fused_block_size];
2294  const int next_T0 = (dataInd < index_size - prefdist_T0)
2295  ? (dataInd + prefdist_T0)
2296  : dataInd;
2297  const int idx_pref_T0 = indices[next_T0];
2298  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2299  return false;
2300  }
2301  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2302  vop0 = _mm256_fmadd_ps(
2303  vwgt,
2304  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2305  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2306  _mm256_add_ps(vop0, vbio));
2307  _mm_prefetch(
2308  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2309  vop8 = _mm256_fmadd_ps(
2310  vwgt,
2311  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2312  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2313  _mm256_add_ps(vop8, vbio));
2314  // skip unnecessary prefetch of (&ip_next_T0[8])
2315  vop16 = _mm256_fmadd_ps(
2316  vwgt,
2317  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2318  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2319  _mm256_add_ps(vop16, vbio));
2320  // skip unnecessary prefetch of (&ip_next_T0[16])
2321  vop24 = _mm256_fmadd_ps(
2322  vwgt,
2323  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2324  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2325  _mm256_add_ps(vop24, vbio));
2326  // skip unnecessary prefetch of (&ip_next_T0[24])
2327  }
2328  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2329  _mm256_storeu_ps(&op[0], vop0);
2330  _mm256_storeu_ps(&op[8], vop8);
2331  _mm256_storeu_ps(&op[16], vop16);
2332  _mm256_storeu_ps(&op[24], vop24);
2333  } else {
2334  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2335  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2336  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2337  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2338  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2339  }
2340  }
2341  } else if (block_size == 16) {
2342  // unrolling 2 times
2343  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2344  float* op = &out[rangeIndex * block_size];
2345  __m256 vop0 = _mm256_setzero_ps();
2346  __m256 vop8 = _mm256_setzero_ps();
2347  if (dataInd + lengths[rangeIndex] > index_size) {
2348  return false;
2349  }
2350  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2351  ++dataInd) {
2352  const int idx = indices[dataInd];
2353  if (idx < 0 || idx >= data_size) {
2354  return false;
2355  }
2356  float wgt = 1.f;
2357  float bio;
2358  if (weights) {
2359  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2360  }
2361  bio = wgt * scale_bias[2 * idx + 1];
2362  wgt = wgt * scale_bias[2 * idx];
2363  __m256 vbio = _mm256_set1_ps(bio);
2364  __m256 vwgt = _mm256_set1_ps(wgt);
2365  const uint8_t* ip = &input[idx * fused_block_size];
2366  const int next_T0 = (dataInd < index_size - prefdist_T0)
2367  ? (dataInd + prefdist_T0)
2368  : dataInd;
2369  const int idx_pref_T0 = indices[next_T0];
2370  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2371  return false;
2372  }
2373  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2374  vop0 = _mm256_fmadd_ps(
2375  vwgt,
2376  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2377  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2378  _mm256_add_ps(vop0, vbio));
2379  _mm_prefetch(
2380  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2381  vop8 = _mm256_fmadd_ps(
2382  vwgt,
2383  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2384  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2385  _mm256_add_ps(vop8, vbio));
2386  // skip unnecessary prefetch of (&ip_next_T0[8])
2387  }
2388  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2389  _mm256_storeu_ps(&op[0], vop0);
2390  _mm256_storeu_ps(&op[8], vop8);
2391  } else {
2392  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2393  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2394  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2395  }
2396  }
2397  } else {
2398  // generic code
2399  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2400  float* op = &out[rangeIndex * block_size];
2401  int64_t j = 0;
2402  for (; j + 8 <= block_size; j += 8) {
2403  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2404  }
2405  for (; j < block_size; j++) {
2406  op[j] = 0.0f;
2407  }
2408  if (dataInd + lengths[rangeIndex] > index_size) {
2409  return false;
2410  }
2411  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2412  ++dataInd) {
2413  const int idx = indices[dataInd];
2414  if (idx < 0 || idx >= data_size) {
2415  return false;
2416  }
2417  float wgt = 1.f;
2418  float bio;
2419  if (weights) {
2420  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2421  }
2422  bio = wgt * scale_bias[2 * idx + 1];
2423  wgt = wgt * scale_bias[2 * idx];
2424  __m256 vbio = _mm256_set1_ps(bio);
2425  __m256 vwgt = _mm256_set1_ps(wgt);
2426  const uint8_t* ip = &input[idx * fused_block_size];
2427  const int next_T0 = (dataInd < index_size - prefdist_T0)
2428  ? (dataInd + prefdist_T0)
2429  : dataInd;
2430  const int idx_pref_T0 = indices[next_T0];
2431  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2432  return false;
2433  }
2434  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2435  j = 0;
2436  for (; j + 8 <= block_size; j += 8) {
2437  _mm256_storeu_ps(
2438  &op[j],
2439  _mm256_fmadd_ps(
2440  vwgt,
2441  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2442  reinterpret_cast<const __m128i*>(&ip[j])))),
2443  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2444  _mm_prefetch(
2445  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
2446  }
2447  for (; j < block_size; j++) {
2448  op[j] += wgt * ((float)ip[j]) + bio;
2449  }
2450  }
2451  if (normalize_by_lengths && lengths[rangeIndex]) {
2452  float len_inv = 1.0f / lengths[rangeIndex];
2453  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2454  j = 0;
2455  for (; j + 8 <= block_size; j += 8) {
2456  _mm256_storeu_ps(
2457  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2458  }
2459  for (; j < block_size; j++) {
2460  op[j] = len_inv * op[j];
2461  }
2462  }
2463  }
2464  }
2465  return dataInd == index_size;
2466 }
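// The uint8 kernels above rely on the rowwise dequantization identity
//   w * (scale * q + bias) == (w * scale) * q + (w * bias),
// which lets them fold the per-row scale into the broadcast weight (vwgt) and
// the per-row bias into a broadcast addend (vbio), so each 8-wide step is a
// single FMA: acc = fmadd(vwgt, cvt(q), acc + vbio). A minimal scalar sketch
// of the same update, with illustrative names:
static inline void DequantAccumulateScalarSketch(
    float w,
    float scale,
    float bias,
    const uint8_t* q_row, // one quantized embedding row
    int64_t block_size,
    float* acc) {
  const float wgt = w * scale; // folded scale
  const float bio = w * bias;  // folded bias
  for (int64_t j = 0; j < block_size; ++j) {
    acc[j] = wgt * static_cast<float>(q_row[j]) + (acc[j] + bio);
  }
}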
2467 bool EmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
2468  const int64_t block_size,
2469  const int64_t output_size,
2470  const int64_t index_size,
2471  const int64_t data_size,
2472  const uint8_t* input,
2473  const int* indices,
2474  const int* lengths,
2475  const float* weights,
2476  const float* scale_bias,
2477  bool normalize_by_lengths,
2478  float* out) {
2479  return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
2480  block_size,
2481  output_size,
2482  index_size,
2483  data_size,
2484  input,
2485  indices,
2486  lengths,
2487  weights,
2488  scale_bias,
2489  normalize_by_lengths,
2490  out);
2491 }
2492 bool EmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
2493  const int64_t block_size,
2494  const int64_t output_size,
2495  const int64_t index_size,
2496  const int64_t data_size,
2497  const uint8_t* input,
2498  const int* indices,
2499  const int* lengths,
2500  const float* weights,
2501  const float* scale_bias,
2502  bool normalize_by_lengths,
2503  float* out) {
2504  return EmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
2505  block_size,
2506  output_size,
2507  index_size,
2508  data_size,
2509  input,
2510  indices,
2511  lengths,
2512  weights,
2513  scale_bias,
2514  normalize_by_lengths,
2515  out);
2516 }
2517 
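// The 8-bit kernels read scale_bias as an interleaved {scale, bias} pair per
// embedding row: scale at scale_bias[2 * idx], bias at scale_bias[2 * idx + 1].
// Below is one plausible way a caller could produce that layout from float
// rows; it is a rough sketch only (min/max rowwise quantization is assumed),
// and the actual Caffe2 quantization operators live elsewhere. Assumes
// <algorithm> for std::min/std::max.
static void QuantizeRowsSketch(
    const float* src,    // data_size x block_size floats
    int64_t data_size,
    int64_t block_size,
    uint8_t* dst,        // data_size x block_size quantized values
    float* scale_bias) { // data_size x 2 interleaved {scale, bias}
  for (int64_t r = 0; r < data_size; ++r) {
    const float* row = src + r * block_size;
    float lo = row[0], hi = row[0];
    for (int64_t j = 1; j < block_size; ++j) {
      lo = std::min(lo, row[j]);
      hi = std::max(hi, row[j]);
    }
    const float scale = (hi - lo) / 255.f;
    const float inv = scale > 0.f ? 1.f / scale : 0.f;
    scale_bias[2 * r] = scale;  // read back by the kernel as scale_bias[2*idx]
    scale_bias[2 * r + 1] = lo; // read back as scale_bias[2*idx + 1]
    for (int64_t j = 0; j < block_size; ++j) {
      dst[r * block_size + j] =
          static_cast<uint8_t>((row[j] - lo) * inv + 0.5f);
    }
  }
}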
2518 template <bool IS_WEIGHT_POSITIONAL>
2519 static bool EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
2520  const int64_t block_size,
2521  const int64_t output_size,
2522  const int64_t index_size,
2523  const int64_t data_size,
2524  const uint8_t* input,
2525  const int64_t* indices,
2526  const int* lengths,
2527  const float* weights,
2528  const float* scale_bias,
2529  bool normalize_by_lengths,
2530  float* out) {
2531  const int64_t prefdist_T0 = 16;
2532  const int64_t fused_block_size = block_size + 0;
2533  int64_t dataInd = 0;
2534  if (block_size == 128) {
2535  // unrolling 16 times
2536  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2537  float* op = &out[rangeIndex * block_size];
2538  __m256 vop0 = _mm256_setzero_ps();
2539  __m256 vop8 = _mm256_setzero_ps();
2540  __m256 vop16 = _mm256_setzero_ps();
2541  __m256 vop24 = _mm256_setzero_ps();
2542  __m256 vop32 = _mm256_setzero_ps();
2543  __m256 vop40 = _mm256_setzero_ps();
2544  __m256 vop48 = _mm256_setzero_ps();
2545  __m256 vop56 = _mm256_setzero_ps();
2546  __m256 vop64 = _mm256_setzero_ps();
2547  __m256 vop72 = _mm256_setzero_ps();
2548  __m256 vop80 = _mm256_setzero_ps();
2549  __m256 vop88 = _mm256_setzero_ps();
2550  __m256 vop96 = _mm256_setzero_ps();
2551  __m256 vop104 = _mm256_setzero_ps();
2552  __m256 vop112 = _mm256_setzero_ps();
2553  __m256 vop120 = _mm256_setzero_ps();
2554  if (dataInd + lengths[rangeIndex] > index_size) {
2555  return false;
2556  }
2557  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2558  ++dataInd) {
2559  const int64_t idx = indices[dataInd];
2560  if (idx < 0 || idx >= data_size) {
2561  return false;
2562  }
2563  float wgt = 1.f;
2564  float bio;
2565  if (weights) {
2566  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2567  }
2568  bio = wgt * scale_bias[2 * idx + 1];
2569  wgt = wgt * scale_bias[2 * idx];
2570  __m256 vbio = _mm256_set1_ps(bio);
2571  __m256 vwgt = _mm256_set1_ps(wgt);
2572  const uint8_t* ip = &input[idx * fused_block_size];
2573  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2574  ? (dataInd + prefdist_T0)
2575  : dataInd;
2576  const int64_t idx_pref_T0 = indices[next_T0];
2577  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2578  return false;
2579  }
2580  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2581  vop0 = _mm256_fmadd_ps(
2582  vwgt,
2583  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2584  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2585  _mm256_add_ps(vop0, vbio));
2586  _mm_prefetch(
2587  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2588  vop8 = _mm256_fmadd_ps(
2589  vwgt,
2590  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2591  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2592  _mm256_add_ps(vop8, vbio));
2593  // skip unnecessary prefetch of (&ip_next_T0[8])
2594  vop16 = _mm256_fmadd_ps(
2595  vwgt,
2596  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2597  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2598  _mm256_add_ps(vop16, vbio));
2599  // skip unnecessary prefetch of (&ip_next_T0[16])
2600  vop24 = _mm256_fmadd_ps(
2601  vwgt,
2602  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2603  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2604  _mm256_add_ps(vop24, vbio));
2605  // skip unnecessary prefetch of (&ip_next_T0[24])
2606  vop32 = _mm256_fmadd_ps(
2607  vwgt,
2608  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2609  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2610  _mm256_add_ps(vop32, vbio));
2611  // skip unnecessary prefetch of (&ip_next_T0[32])
2612  vop40 = _mm256_fmadd_ps(
2613  vwgt,
2614  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2615  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2616  _mm256_add_ps(vop40, vbio));
2617  // skip unnecessary prefetch of (&ip_next_T0[40])
2618  vop48 = _mm256_fmadd_ps(
2619  vwgt,
2620  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2621  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2622  _mm256_add_ps(vop48, vbio));
2623  // skip unnecessary prefetch of (&ip_next_T0[48])
2624  vop56 = _mm256_fmadd_ps(
2625  vwgt,
2626  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2627  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2628  _mm256_add_ps(vop56, vbio));
2629  // skip unnecessary prefetch of (&ip_next_T0[56])
2630  vop64 = _mm256_fmadd_ps(
2631  vwgt,
2632  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2633  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2634  _mm256_add_ps(vop64, vbio));
2635  _mm_prefetch(
2636  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2637  vop72 = _mm256_fmadd_ps(
2638  vwgt,
2639  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2640  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2641  _mm256_add_ps(vop72, vbio));
2642  // skip unnecessary prefetch of (&ip_next_T0[72])
2643  vop80 = _mm256_fmadd_ps(
2644  vwgt,
2645  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2646  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2647  _mm256_add_ps(vop80, vbio));
2648  // skip unnecessary prefetch of (&ip_next_T0[80])
2649  vop88 = _mm256_fmadd_ps(
2650  vwgt,
2651  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2652  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2653  _mm256_add_ps(vop88, vbio));
2654  // skip unnecessary prefetch of (&ip_next_T0[88])
2655  vop96 = _mm256_fmadd_ps(
2656  vwgt,
2657  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2658  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2659  _mm256_add_ps(vop96, vbio));
2660  // skip unnecessary prefetch of (&ip_next_T0[96])
2661  vop104 = _mm256_fmadd_ps(
2662  vwgt,
2663  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2664  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2665  _mm256_add_ps(vop104, vbio));
2666  // skip unnecessary prefetch of (&ip_next_T0[104])
2667  vop112 = _mm256_fmadd_ps(
2668  vwgt,
2669  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2670  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2671  _mm256_add_ps(vop112, vbio));
2672  // skip unnecessary prefetch of (&ip_next_T0[112])
2673  vop120 = _mm256_fmadd_ps(
2674  vwgt,
2675  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2676  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2677  _mm256_add_ps(vop120, vbio));
2678  // skip unnecessary prefetch of (&ip_next_T0[120])
2679  }
2680  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2681  _mm256_storeu_ps(&op[0], vop0);
2682  _mm256_storeu_ps(&op[8], vop8);
2683  _mm256_storeu_ps(&op[16], vop16);
2684  _mm256_storeu_ps(&op[24], vop24);
2685  _mm256_storeu_ps(&op[32], vop32);
2686  _mm256_storeu_ps(&op[40], vop40);
2687  _mm256_storeu_ps(&op[48], vop48);
2688  _mm256_storeu_ps(&op[56], vop56);
2689  _mm256_storeu_ps(&op[64], vop64);
2690  _mm256_storeu_ps(&op[72], vop72);
2691  _mm256_storeu_ps(&op[80], vop80);
2692  _mm256_storeu_ps(&op[88], vop88);
2693  _mm256_storeu_ps(&op[96], vop96);
2694  _mm256_storeu_ps(&op[104], vop104);
2695  _mm256_storeu_ps(&op[112], vop112);
2696  _mm256_storeu_ps(&op[120], vop120);
2697  } else {
2698  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2699  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2700  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2701  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2702  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2703  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2704  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2705  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2706  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2707  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2708  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2709  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2710  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2711  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2712  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2713  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2714  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2715  }
2716  }
2717  } else if (block_size == 64) {
2718  // unrolling 8 times
2719  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2720  float* op = &out[rangeIndex * block_size];
2721  __m256 vop0 = _mm256_setzero_ps();
2722  __m256 vop8 = _mm256_setzero_ps();
2723  __m256 vop16 = _mm256_setzero_ps();
2724  __m256 vop24 = _mm256_setzero_ps();
2725  __m256 vop32 = _mm256_setzero_ps();
2726  __m256 vop40 = _mm256_setzero_ps();
2727  __m256 vop48 = _mm256_setzero_ps();
2728  __m256 vop56 = _mm256_setzero_ps();
2729  if (dataInd + lengths[rangeIndex] > index_size) {
2730  return false;
2731  }
2732  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2733  ++dataInd) {
2734  const int64_t idx = indices[dataInd];
2735  if (idx < 0 || idx >= data_size) {
2736  return false;
2737  }
2738  float wgt = 1.f;
2739  float bio;
2740  if (weights) {
2741  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2742  }
2743  bio = wgt * scale_bias[2 * idx + 1];
2744  wgt = wgt * scale_bias[2 * idx];
2745  __m256 vbio = _mm256_set1_ps(bio);
2746  __m256 vwgt = _mm256_set1_ps(wgt);
2747  const uint8_t* ip = &input[idx * fused_block_size];
2748  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2749  ? (dataInd + prefdist_T0)
2750  : dataInd;
2751  const int64_t idx_pref_T0 = indices[next_T0];
2752  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2753  return false;
2754  }
2755  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2756  vop0 = _mm256_fmadd_ps(
2757  vwgt,
2758  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2759  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2760  _mm256_add_ps(vop0, vbio));
2761  _mm_prefetch(
2762  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2763  vop8 = _mm256_fmadd_ps(
2764  vwgt,
2765  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2766  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2767  _mm256_add_ps(vop8, vbio));
2768  // skip unnecessary prefetch of (&ip_next_T0[8])
2769  vop16 = _mm256_fmadd_ps(
2770  vwgt,
2771  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2772  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2773  _mm256_add_ps(vop16, vbio));
2774  // skip unnecessary prefetch of (&ip_next_T0[16])
2775  vop24 = _mm256_fmadd_ps(
2776  vwgt,
2777  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2778  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2779  _mm256_add_ps(vop24, vbio));
2780  // skip unnecessary prefetch of (&ip_next_T0[24])
2781  vop32 = _mm256_fmadd_ps(
2782  vwgt,
2783  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2784  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2785  _mm256_add_ps(vop32, vbio));
2786  // skip unnecessary prefetch of (&ip_next_T0[32])
2787  vop40 = _mm256_fmadd_ps(
2788  vwgt,
2789  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2790  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2791  _mm256_add_ps(vop40, vbio));
2792  // skip unnecessary prefetch of (&ip_next_T0[40])
2793  vop48 = _mm256_fmadd_ps(
2794  vwgt,
2795  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2796  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2797  _mm256_add_ps(vop48, vbio));
2798  // skip unnecessary prefetch of (&ip_next_T0[48])
2799  vop56 = _mm256_fmadd_ps(
2800  vwgt,
2801  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2802  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2803  _mm256_add_ps(vop56, vbio));
2804  // skip unnecessary prefetch of (&ip_next_T0[56])
2805  }
2806  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2807  _mm256_storeu_ps(&op[0], vop0);
2808  _mm256_storeu_ps(&op[8], vop8);
2809  _mm256_storeu_ps(&op[16], vop16);
2810  _mm256_storeu_ps(&op[24], vop24);
2811  _mm256_storeu_ps(&op[32], vop32);
2812  _mm256_storeu_ps(&op[40], vop40);
2813  _mm256_storeu_ps(&op[48], vop48);
2814  _mm256_storeu_ps(&op[56], vop56);
2815  } else {
2816  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2817  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2818  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2819  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2820  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2821  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2822  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2823  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2824  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2825  }
2826  }
2827  } else if (block_size == 32) {
2828  // unrolling 4 times
2829  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2830  float* op = &out[rangeIndex * block_size];
2831  __m256 vop0 = _mm256_setzero_ps();
2832  __m256 vop8 = _mm256_setzero_ps();
2833  __m256 vop16 = _mm256_setzero_ps();
2834  __m256 vop24 = _mm256_setzero_ps();
2835  if (dataInd + lengths[rangeIndex] > index_size) {
2836  return false;
2837  }
2838  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2839  ++dataInd) {
2840  const int64_t idx = indices[dataInd];
2841  if (idx < 0 || idx >= data_size) {
2842  return false;
2843  }
2844  float wgt = 1.f;
2845  float bio;
2846  if (weights) {
2847  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2848  }
2849  bio = wgt * scale_bias[2 * idx + 1];
2850  wgt = wgt * scale_bias[2 * idx];
2851  __m256 vbio = _mm256_set1_ps(bio);
2852  __m256 vwgt = _mm256_set1_ps(wgt);
2853  const uint8_t* ip = &input[idx * fused_block_size];
2854  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2855  ? (dataInd + prefdist_T0)
2856  : dataInd;
2857  const int64_t idx_pref_T0 = indices[next_T0];
2858  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2859  return false;
2860  }
2861  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2862  vop0 = _mm256_fmadd_ps(
2863  vwgt,
2864  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2865  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2866  _mm256_add_ps(vop0, vbio));
2867  _mm_prefetch(
2868  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2869  vop8 = _mm256_fmadd_ps(
2870  vwgt,
2871  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2872  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2873  _mm256_add_ps(vop8, vbio));
2874  // skip unnecessary prefetch of (&ip_next_T0[8])
2875  vop16 = _mm256_fmadd_ps(
2876  vwgt,
2877  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2878  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2879  _mm256_add_ps(vop16, vbio));
2880  // skip unnecessary prefetch of (&ip_next_T0[16])
2881  vop24 = _mm256_fmadd_ps(
2882  vwgt,
2883  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2884  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2885  _mm256_add_ps(vop24, vbio));
2886  // skip unnecessary prefetch of (&ip_next_T0[24])
2887  }
2888  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2889  _mm256_storeu_ps(&op[0], vop0);
2890  _mm256_storeu_ps(&op[8], vop8);
2891  _mm256_storeu_ps(&op[16], vop16);
2892  _mm256_storeu_ps(&op[24], vop24);
2893  } else {
2894  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2895  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2896  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2897  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2898  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2899  }
2900  }
2901  } else if (block_size == 16) {
2902  // unrolling 2 times
2903  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2904  float* op = &out[rangeIndex * block_size];
2905  __m256 vop0 = _mm256_setzero_ps();
2906  __m256 vop8 = _mm256_setzero_ps();
2907  if (dataInd + lengths[rangeIndex] > index_size) {
2908  return false;
2909  }
2910  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2911  ++dataInd) {
2912  const int64_t idx = indices[dataInd];
2913  if (idx < 0 || idx >= data_size) {
2914  return false;
2915  }
2916  float wgt = 1.f;
2917  float bio;
2918  if (weights) {
2919  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2920  }
2921  bio = wgt * scale_bias[2 * idx + 1];
2922  wgt = wgt * scale_bias[2 * idx];
2923  __m256 vbio = _mm256_set1_ps(bio);
2924  __m256 vwgt = _mm256_set1_ps(wgt);
2925  const uint8_t* ip = &input[idx * fused_block_size];
2926  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2927  ? (dataInd + prefdist_T0)
2928  : dataInd;
2929  const int64_t idx_pref_T0 = indices[next_T0];
2930  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2931  return false;
2932  }
2933  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2934  vop0 = _mm256_fmadd_ps(
2935  vwgt,
2936  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2937  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2938  _mm256_add_ps(vop0, vbio));
2939  _mm_prefetch(
2940  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2941  vop8 = _mm256_fmadd_ps(
2942  vwgt,
2943  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2944  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2945  _mm256_add_ps(vop8, vbio));
2946  // skip unnecessary prefetch of (&ip_next_T0[8])
2947  }
2948  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2949  _mm256_storeu_ps(&op[0], vop0);
2950  _mm256_storeu_ps(&op[8], vop8);
2951  } else {
2952  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2953  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2954  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2955  }
2956  }
2957  } else {
2958  // generic code
2959  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2960  float* op = &out[rangeIndex * block_size];
2961  int64_t j = 0;
2962  for (; j + 8 <= block_size; j += 8) {
2963  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2964  }
2965  for (; j < block_size; j++) {
2966  op[j] = 0.0f;
2967  }
2968  if (dataInd + lengths[rangeIndex] > index_size) {
2969  return false;
2970  }
2971  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2972  ++dataInd) {
2973  const int64_t idx = indices[dataInd];
2974  if (idx < 0 || idx >= data_size) {
2975  return false;
2976  }
2977  float wgt = 1.f;
2978  float bio;
2979  if (weights) {
2980  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2981  }
2982  bio = wgt * scale_bias[2 * idx + 1];
2983  wgt = wgt * scale_bias[2 * idx];
2984  __m256 vbio = _mm256_set1_ps(bio);
2985  __m256 vwgt = _mm256_set1_ps(wgt);
2986  const uint8_t* ip = &input[idx * fused_block_size];
2987  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2988  ? (dataInd + prefdist_T0)
2989  : dataInd;
2990  const int64_t idx_pref_T0 = indices[next_T0];
2991  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2992  return false;
2993  }
2994  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2995  j = 0;
2996  for (; j + 8 <= block_size; j += 8) {
2997  _mm256_storeu_ps(
2998  &op[j],
2999  _mm256_fmadd_ps(
3000  vwgt,
3001  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
3002  reinterpret_cast<const __m128i*>(&ip[j])))),
3003  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
3004  _mm_prefetch(
3005  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
3006  }
3007  for (; j < block_size; j++) {
3008  op[j] += wgt * ((float)ip[j]) + bio;
3009  }
3010  }
3011  if (normalize_by_lengths && lengths[rangeIndex]) {
3012  float len_inv = 1.0f / lengths[rangeIndex];
3013  __m256 vlen_inv = _mm256_set1_ps(len_inv);
3014  j = 0;
3015  for (; j + 8 <= block_size; j += 8) {
3016  _mm256_storeu_ps(
3017  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
3018  }
3019  for (; j < block_size; j++) {
3020  op[j] = len_inv * op[j];
3021  }
3022  }
3023  }
3024  }
3025  return dataInd == index_size;
3026 }
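// All specializations above use the same software-prefetch pattern: while
// accumulating row indices[dataInd], they prefetch the row that will be needed
// prefdist_T0 (= 16) indices later, issuing one _mm_prefetch per 64-byte cache
// line and skipping the rest (hence the "skip unnecessary prefetch" comments).
// A stripped-down sketch of just that pattern, with illustrative names:
static inline void PrefetchRowAheadSketch(
    const uint8_t* input,
    const int64_t* indices,
    int64_t dataInd,
    int64_t index_size,
    int64_t bytes_per_row) {
  const int64_t prefdist = 16;
  const int64_t next =
      (dataInd < index_size - prefdist) ? (dataInd + prefdist) : dataInd;
  const uint8_t* row = input + indices[next] * bytes_per_row;
  for (int64_t b = 0; b < bytes_per_row; b += 64) { // one hint per cache line
    _mm_prefetch(reinterpret_cast<const char*>(row + b), _MM_HINT_T0);
  }
}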
3027 bool EmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
3028  const int64_t block_size,
3029  const int64_t output_size,
3030  const int64_t index_size,
3031  const int64_t data_size,
3032  const uint8_t* input,
3033  const int64_t* indices,
3034  const int* lengths,
3035  const float* weights,
3036  const float* scale_bias,
3037  bool normalize_by_lengths,
3038  float* out) {
3039  return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
3040  block_size,
3041  output_size,
3042  index_size,
3043  data_size,
3044  input,
3045  indices,
3046  lengths,
3047  weights,
3048  scale_bias,
3049  normalize_by_lengths,
3050  out);
3051 }
3052 bool EmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
3053  const int64_t block_size,
3054  const int64_t output_size,
3055  const int64_t index_size,
3056  const int64_t data_size,
3057  const uint8_t* input,
3058  const int64_t* indices,
3059  const int* lengths,
3060  const float* weights,
3061  const float* scale_bias,
3062  bool normalize_by_lengths,
3063  float* out) {
3064  return EmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
3065  block_size,
3066  output_size,
3067  index_size,
3068  data_size,
3069  input,
3070  indices,
3071  lengths,
3072  weights,
3073  scale_bias,
3074  normalize_by_lengths,
3075  out);
3076 }
3077 
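// Every variant in this file returns true only when each index lies in
// [0, data_size) and the lengths exactly consume index_size entries
// (dataInd == index_size at the end); callers treat false as a bounds or
// consistency error. A caller-side sanity check with the same contract might
// look like this (illustrative helper, not exported):
static bool LengthsCoverIndicesSketch(
    const int* lengths,
    int64_t output_size,
    int64_t index_size) {
  int64_t total = 0;
  for (int64_t i = 0; i < output_size; ++i) {
    total += lengths[i];
  }
  return total == index_size;
}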
3078 } // namespace caffe2