Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding_lookup_fused_8bit_rowwise_avx2.cc
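The listing below contains the AVX2+FMA specializations of the Fused8BitRowwiseEmbeddingLookup kernels for float and at::Half inputs with int32_t and int64_t indices. Each input row stores block_size values followed by 8 reserved bytes (2 floats, or 4 halfs), which these paths simply skip via fused_block_size; for every output segment defined by lengths, the selected rows are accumulated as a weighted sum, optionally normalized by the segment length, and the kernels return false on out-of-range indices or lengths that do not sum exactly to index_size. As a reading aid, here is a scalar reference for the int32_t/float case; it is a sketch, not part of the generated file, and it omits the IS_WEIGHT_POSITIONAL option (positional weights are indexed by position within the segment rather than by the running index):

#include <algorithm>
#include <cstdint>

// Scalar reference (illustrative name; non-positional weights only) for the
// int32_t/float kernels in this file.
static bool Fused8BitRowwiseLookupRef_float(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,   // rows of block_size floats + 2 reserved floats
    const int* indices,
    const int* lengths,
    const float* weights, // may be nullptr
    bool normalize_by_lengths,
    float* out) {
  const int64_t fused_block_size = block_size + 2;
  int64_t dataInd = 0;
  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
    float* op = out + rangeIndex * block_size;
    std::fill(op, op + block_size, 0.0f);
    if (dataInd + lengths[rangeIndex] > index_size) {
      return false; // lengths walk past the end of indices
    }
    for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
         ++dataInd) {
      const int idx = indices[dataInd];
      if (idx < 0 || idx >= data_size) {
        return false; // out-of-range row
      }
      const float wgt = weights ? weights[dataInd] : 1.0f;
      const float* ip = input + idx * fused_block_size;
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] += wgt * ip[j]; // the 2 trailing floats per row are skipped
      }
    }
    if (normalize_by_lengths && lengths[rangeIndex]) {
      const float len_inv = 1.0f / lengths[rangeIndex];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
  return dataInd == index_size;
}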
8 #include <c10/util/Half.h>
9 #include <immintrin.h>
10 namespace caffe2 {
11 
12 template <bool IS_WEIGHT_POSITIONAL>
13 static bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
14  const int64_t block_size,
15  const int64_t output_size,
16  const int64_t index_size,
17  const int64_t data_size,
18  const float* input,
19  const int* indices,
20  const int* lengths,
21  const float* weights,
22  bool normalize_by_lengths,
23  float* out) {
24  const int prefdist_T0 = 16;
25  const int fused_block_size = block_size + 2;
26  int dataInd = 0;
27  if (block_size == 128) {
28  // unrolling 16 times
29  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
30  float* op = &out[rangeIndex * block_size];
31  __m256 vop0 = _mm256_setzero_ps();
32  __m256 vop8 = _mm256_setzero_ps();
33  __m256 vop16 = _mm256_setzero_ps();
34  __m256 vop24 = _mm256_setzero_ps();
35  __m256 vop32 = _mm256_setzero_ps();
36  __m256 vop40 = _mm256_setzero_ps();
37  __m256 vop48 = _mm256_setzero_ps();
38  __m256 vop56 = _mm256_setzero_ps();
39  __m256 vop64 = _mm256_setzero_ps();
40  __m256 vop72 = _mm256_setzero_ps();
41  __m256 vop80 = _mm256_setzero_ps();
42  __m256 vop88 = _mm256_setzero_ps();
43  __m256 vop96 = _mm256_setzero_ps();
44  __m256 vop104 = _mm256_setzero_ps();
45  __m256 vop112 = _mm256_setzero_ps();
46  __m256 vop120 = _mm256_setzero_ps();
47  if (dataInd + lengths[rangeIndex] > index_size) {
48  return false;
49  }
50  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
51  ++dataInd) {
52  const int idx = indices[dataInd];
53  if (idx < 0 || idx >= data_size) {
54  return false;
55  }
56  float wgt = 1.f;
57  if (weights) {
58  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
59  }
60  __m256 vwgt = _mm256_set1_ps(wgt);
61  const float* ip = &input[idx * fused_block_size];
62  const int next_T0 = (dataInd < index_size - prefdist_T0)
63  ? (dataInd + prefdist_T0)
64  : dataInd;
65  const int idx_pref_T0 = indices[next_T0];
66  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
67  return false;
68  }
69  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
70  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
71  _mm_prefetch(
72  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
73  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
74  // skip unnecessary prefetch of (&ip_next_T0[8])
75  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
76  _mm_prefetch(
77  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
78  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
79  // skip unnecessary prefetch of (&ip_next_T0[24])
80  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
81  _mm_prefetch(
82  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
83  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
84  // skip unnecessary prefetch of (&ip_next_T0[40])
85  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
86  _mm_prefetch(
87  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
88  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
89  // skip unnecessary prefetch of (&ip_next_T0[56])
90  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
91  _mm_prefetch(
92  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
93  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
94  // skip unnecessary prefetch of (&ip_next_T0[72])
95  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
96  _mm_prefetch(
97  reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
98  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
99  // skip unnecessary prefetch of (&ip_next_T0[88])
100  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
101  _mm_prefetch(
102  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
103  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
104  // skip unnecessary prefetch of (&ip_next_T0[104])
105  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
106  _mm_prefetch(
107  reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
108  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
109  // skip unnecessary prefetch of (&ip_next_T0[120])
110  }
111  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
112  _mm256_storeu_ps(&op[0], vop0);
113  _mm256_storeu_ps(&op[8], vop8);
114  _mm256_storeu_ps(&op[16], vop16);
115  _mm256_storeu_ps(&op[24], vop24);
116  _mm256_storeu_ps(&op[32], vop32);
117  _mm256_storeu_ps(&op[40], vop40);
118  _mm256_storeu_ps(&op[48], vop48);
119  _mm256_storeu_ps(&op[56], vop56);
120  _mm256_storeu_ps(&op[64], vop64);
121  _mm256_storeu_ps(&op[72], vop72);
122  _mm256_storeu_ps(&op[80], vop80);
123  _mm256_storeu_ps(&op[88], vop88);
124  _mm256_storeu_ps(&op[96], vop96);
125  _mm256_storeu_ps(&op[104], vop104);
126  _mm256_storeu_ps(&op[112], vop112);
127  _mm256_storeu_ps(&op[120], vop120);
128  } else {
129  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
130  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
131  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
132  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
133  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
134  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
135  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
136  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
137  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
138  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
139  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
140  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
141  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
142  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
143  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
144  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
145  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
146  }
147  }
148  } else if (block_size == 64) {
149  // unrolling 8 times
150  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
151  float* op = &out[rangeIndex * block_size];
152  __m256 vop0 = _mm256_setzero_ps();
153  __m256 vop8 = _mm256_setzero_ps();
154  __m256 vop16 = _mm256_setzero_ps();
155  __m256 vop24 = _mm256_setzero_ps();
156  __m256 vop32 = _mm256_setzero_ps();
157  __m256 vop40 = _mm256_setzero_ps();
158  __m256 vop48 = _mm256_setzero_ps();
159  __m256 vop56 = _mm256_setzero_ps();
160  if (dataInd + lengths[rangeIndex] > index_size) {
161  return false;
162  }
163  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
164  ++dataInd) {
165  const int idx = indices[dataInd];
166  if (idx < 0 || idx >= data_size) {
167  return false;
168  }
169  float wgt = 1.f;
170  if (weights) {
171  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
172  }
173  __m256 vwgt = _mm256_set1_ps(wgt);
174  const float* ip = &input[idx * fused_block_size];
175  const int next_T0 = (dataInd < index_size - prefdist_T0)
176  ? (dataInd + prefdist_T0)
177  : dataInd;
178  const int idx_pref_T0 = indices[next_T0];
179  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
180  return false;
181  }
182  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
183  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
184  _mm_prefetch(
185  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
186  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
187  // skip unnecessary prefetch of (&ip_next_T0[8])
188  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
189  _mm_prefetch(
190  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
191  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
192  // skip unnecessary prefetch of (&ip_next_T0[24])
193  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
194  _mm_prefetch(
195  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
196  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
197  // skip unnecessary prefetch of (&ip_next_T0[40])
198  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
199  _mm_prefetch(
200  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
201  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
202  // skip unnecessary prefetch of (&ip_next_T0[56])
203  }
204  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
205  _mm256_storeu_ps(&op[0], vop0);
206  _mm256_storeu_ps(&op[8], vop8);
207  _mm256_storeu_ps(&op[16], vop16);
208  _mm256_storeu_ps(&op[24], vop24);
209  _mm256_storeu_ps(&op[32], vop32);
210  _mm256_storeu_ps(&op[40], vop40);
211  _mm256_storeu_ps(&op[48], vop48);
212  _mm256_storeu_ps(&op[56], vop56);
213  } else {
214  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
215  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
216  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
217  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
218  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
219  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
220  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
221  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
222  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
223  }
224  }
225  } else if (block_size == 32) {
226  // unrolling 4 times
227  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
228  float* op = &out[rangeIndex * block_size];
229  __m256 vop0 = _mm256_setzero_ps();
230  __m256 vop8 = _mm256_setzero_ps();
231  __m256 vop16 = _mm256_setzero_ps();
232  __m256 vop24 = _mm256_setzero_ps();
233  if (dataInd + lengths[rangeIndex] > index_size) {
234  return false;
235  }
236  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
237  ++dataInd) {
238  const int idx = indices[dataInd];
239  if (idx < 0 || idx >= data_size) {
240  return false;
241  }
242  float wgt = 1.f;
243  if (weights) {
244  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
245  }
246  __m256 vwgt = _mm256_set1_ps(wgt);
247  const float* ip = &input[idx * fused_block_size];
248  const int next_T0 = (dataInd < index_size - prefdist_T0)
249  ? (dataInd + prefdist_T0)
250  : dataInd;
251  const int idx_pref_T0 = indices[next_T0];
252  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
253  return false;
254  }
255  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
256  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
257  _mm_prefetch(
258  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
259  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
260  // skip unnecessary prefetch of (&ip_next_T0[8])
261  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
262  _mm_prefetch(
263  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
264  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
265  // skip unnecessary prefetch of (&ip_next_T0[24])
266  }
267  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
268  _mm256_storeu_ps(&op[0], vop0);
269  _mm256_storeu_ps(&op[8], vop8);
270  _mm256_storeu_ps(&op[16], vop16);
271  _mm256_storeu_ps(&op[24], vop24);
272  } else {
273  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
274  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
275  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
276  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
277  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
278  }
279  }
280  } else if (block_size == 16) {
281  // unrolling 2 times
282  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
283  float* op = &out[rangeIndex * block_size];
284  __m256 vop0 = _mm256_setzero_ps();
285  __m256 vop8 = _mm256_setzero_ps();
286  if (dataInd + lengths[rangeIndex] > index_size) {
287  return false;
288  }
289  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
290  ++dataInd) {
291  const int idx = indices[dataInd];
292  if (idx < 0 || idx >= data_size) {
293  return false;
294  }
295  float wgt = 1.f;
296  if (weights) {
297  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
298  }
299  __m256 vwgt = _mm256_set1_ps(wgt);
300  const float* ip = &input[idx * fused_block_size];
301  const int next_T0 = (dataInd < index_size - prefdist_T0)
302  ? (dataInd + prefdist_T0)
303  : dataInd;
304  const int idx_pref_T0 = indices[next_T0];
305  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
306  return false;
307  }
308  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
309  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
310  _mm_prefetch(
311  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
312  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
313  // skip unnecessary prefetch of (&ip_next_T0[8])
314  }
315  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
316  _mm256_storeu_ps(&op[0], vop0);
317  _mm256_storeu_ps(&op[8], vop8);
318  } else {
319  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
320  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
321  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
322  }
323  }
324  } else {
325  // generic code
326  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
327  float* op = &out[rangeIndex * block_size];
328  int64_t j = 0;
329  for (; j + 8 <= block_size; j += 8) {
330  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
331  }
332  for (; j < block_size; j++) {
333  op[j] = 0.0f;
334  }
335  if (dataInd + lengths[rangeIndex] > index_size) {
336  return false;
337  }
338  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
339  ++dataInd) {
340  const int idx = indices[dataInd];
341  if (idx < 0 || idx >= data_size) {
342  return false;
343  }
344  float wgt = 1.f;
345  if (weights) {
346  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
347  }
348  __m256 vwgt = _mm256_set1_ps(wgt);
349  const float* ip = &input[idx * fused_block_size];
350  const int next_T0 = (dataInd < index_size - prefdist_T0)
351  ? (dataInd + prefdist_T0)
352  : dataInd;
353  const int idx_pref_T0 = indices[next_T0];
354  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
355  return false;
356  }
357  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
358  j = 0;
359  for (; j + 8 <= block_size; j += 8) {
360  _mm256_storeu_ps(
361  &op[j],
362  _mm256_fmadd_ps(
363  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
364  _mm_prefetch(
365  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
366  }
367  for (; j < block_size; j++) {
368  op[j] += wgt * ip[j];
369  }
370  }
371  if (normalize_by_lengths && lengths[rangeIndex]) {
372  float len_inv = 1.0f / lengths[rangeIndex];
373  __m256 vlen_inv = _mm256_set1_ps(len_inv);
374  j = 0;
375  for (; j + 8 <= block_size; j += 8) {
376  _mm256_storeu_ps(
377  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
378  }
379  for (; j < block_size; j++) {
380  op[j] = len_inv * op[j];
381  }
382  }
383  }
384  }
385  return dataInd == index_size;
386 }
387 bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma(
388  const int64_t block_size,
389  const int64_t output_size,
390  const int64_t index_size,
391  const int64_t data_size,
392  const float* input,
393  const int* indices,
394  const int* lengths,
395  const float* weights,
396  bool normalize_by_lengths,
397  float* out) {
398  return Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<false>(
399  block_size,
400  output_size,
401  index_size,
402  data_size,
403  input,
404  indices,
405  lengths,
406  weights,
407  normalize_by_lengths,
408  out);
409 }
410 bool Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_true__avx2_fma(
411  const int64_t block_size,
412  const int64_t output_size,
413  const int64_t index_size,
414  const int64_t data_size,
415  const float* input,
416  const int* indices,
417  const int* lengths,
418  const float* weights,
419  bool normalize_by_lengths,
420  float* out) {
421  return Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma<true>(
422  block_size,
423  output_size,
424  index_size,
425  data_size,
426  input,
427  indices,
428  lengths,
429  weights,
430  normalize_by_lengths,
431  out);
432 }
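A minimal, hypothetical call site for the non-positional int32_t/float wrapper above. The values are illustrative, and a matching declaration from the perfkernels headers is assumed to be in scope (the symbol lives in namespace caffe2):

#include <cstdint>
#include <vector>

void example_fused_lookup() {
  const int64_t block_size = 16, output_size = 2, index_size = 4, data_size = 4;
  // 4 rows of 16 floats, each followed by 2 reserved floats (8 bytes).
  std::vector<float> table(data_size * (block_size + 2), 1.0f);
  std::vector<int> indices = {0, 2, 1, 3}; // 4 indices in total
  std::vector<int> lengths = {2, 2};       // two segments of two rows each
  std::vector<float> out(output_size * block_size, 0.0f);
  const bool ok =
      caffe2::Fused8BitRowwiseEmbeddingLookup_int32_t_float_float_false__avx2_fma(
          block_size, output_size, index_size, data_size,
          table.data(), indices.data(), lengths.data(),
          /*weights=*/nullptr, /*normalize_by_lengths=*/true, out.data());
  // ok is false if any index is outside [0, data_size) or the lengths do not
  // sum to index_size; here each output row ends up as the mean of its rows.
  (void)ok;
}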
433 
434 template <bool IS_WEIGHT_POSITIONAL>
435 static bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
436  const int64_t block_size,
437  const int64_t output_size,
438  const int64_t index_size,
439  const int64_t data_size,
440  const float* input,
441  const int64_t* indices,
442  const int* lengths,
443  const float* weights,
444  bool normalize_by_lengths,
445  float* out) {
446  const int64_t prefdist_T0 = 16;
447  const int64_t fused_block_size = block_size + 2;
448  int64_t dataInd = 0;
449  if (block_size == 128) {
450  // unrolling 16 times
451  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
452  float* op = &out[rangeIndex * block_size];
453  __m256 vop0 = _mm256_setzero_ps();
454  __m256 vop8 = _mm256_setzero_ps();
455  __m256 vop16 = _mm256_setzero_ps();
456  __m256 vop24 = _mm256_setzero_ps();
457  __m256 vop32 = _mm256_setzero_ps();
458  __m256 vop40 = _mm256_setzero_ps();
459  __m256 vop48 = _mm256_setzero_ps();
460  __m256 vop56 = _mm256_setzero_ps();
461  __m256 vop64 = _mm256_setzero_ps();
462  __m256 vop72 = _mm256_setzero_ps();
463  __m256 vop80 = _mm256_setzero_ps();
464  __m256 vop88 = _mm256_setzero_ps();
465  __m256 vop96 = _mm256_setzero_ps();
466  __m256 vop104 = _mm256_setzero_ps();
467  __m256 vop112 = _mm256_setzero_ps();
468  __m256 vop120 = _mm256_setzero_ps();
469  if (dataInd + lengths[rangeIndex] > index_size) {
470  return false;
471  }
472  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
473  ++dataInd) {
474  const int64_t idx = indices[dataInd];
475  if (idx < 0 || idx >= data_size) {
476  return false;
477  }
478  float wgt = 1.f;
479  if (weights) {
480  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
481  }
482  __m256 vwgt = _mm256_set1_ps(wgt);
483  const float* ip = &input[idx * fused_block_size];
484  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
485  ? (dataInd + prefdist_T0)
486  : dataInd;
487  const int64_t idx_pref_T0 = indices[next_T0];
488  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
489  return false;
490  }
491  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
492  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
493  _mm_prefetch(
494  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
495  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
496  // skip unnecessary prefetch of (&ip_next_T0[8])
497  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
498  _mm_prefetch(
499  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
500  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
501  // skip unnecessary prefetch of (&ip_next_T0[24])
502  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
503  _mm_prefetch(
504  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
505  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
506  // skip unnecessary prefetch of (&ip_next_T0[40])
507  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
508  _mm_prefetch(
509  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
510  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
511  // skip unnecessary prefetch of (&ip_next_T0[56])
512  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
513  _mm_prefetch(
514  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
515  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
516  // skip unnecessary prefetch of (&ip_next_T0[72])
517  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
518  _mm_prefetch(
519  reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
520  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
521  // skip unnecessary prefetch of (&ip_next_T0[88])
522  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
523  _mm_prefetch(
524  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
525  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
526  // skip unnecessary prefetch of (&ip_next_T0[104])
527  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
528  _mm_prefetch(
529  reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
530  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
531  // skip unnecessary prefetch of (&ip_next_T0[120])
532  }
533  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
534  _mm256_storeu_ps(&op[0], vop0);
535  _mm256_storeu_ps(&op[8], vop8);
536  _mm256_storeu_ps(&op[16], vop16);
537  _mm256_storeu_ps(&op[24], vop24);
538  _mm256_storeu_ps(&op[32], vop32);
539  _mm256_storeu_ps(&op[40], vop40);
540  _mm256_storeu_ps(&op[48], vop48);
541  _mm256_storeu_ps(&op[56], vop56);
542  _mm256_storeu_ps(&op[64], vop64);
543  _mm256_storeu_ps(&op[72], vop72);
544  _mm256_storeu_ps(&op[80], vop80);
545  _mm256_storeu_ps(&op[88], vop88);
546  _mm256_storeu_ps(&op[96], vop96);
547  _mm256_storeu_ps(&op[104], vop104);
548  _mm256_storeu_ps(&op[112], vop112);
549  _mm256_storeu_ps(&op[120], vop120);
550  } else {
551  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
552  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
553  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
554  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
555  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
556  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
557  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
558  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
559  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
560  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
561  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
562  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
563  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
564  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
565  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
566  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
567  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
568  }
569  }
570  } else if (block_size == 64) {
571  // unrolling 8 times
572  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
573  float* op = &out[rangeIndex * block_size];
574  __m256 vop0 = _mm256_setzero_ps();
575  __m256 vop8 = _mm256_setzero_ps();
576  __m256 vop16 = _mm256_setzero_ps();
577  __m256 vop24 = _mm256_setzero_ps();
578  __m256 vop32 = _mm256_setzero_ps();
579  __m256 vop40 = _mm256_setzero_ps();
580  __m256 vop48 = _mm256_setzero_ps();
581  __m256 vop56 = _mm256_setzero_ps();
582  if (dataInd + lengths[rangeIndex] > index_size) {
583  return false;
584  }
585  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
586  ++dataInd) {
587  const int64_t idx = indices[dataInd];
588  if (idx < 0 || idx >= data_size) {
589  return false;
590  }
591  float wgt = 1.f;
592  if (weights) {
593  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
594  }
595  __m256 vwgt = _mm256_set1_ps(wgt);
596  const float* ip = &input[idx * fused_block_size];
597  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
598  ? (dataInd + prefdist_T0)
599  : dataInd;
600  const int64_t idx_pref_T0 = indices[next_T0];
601  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
602  return false;
603  }
604  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
605  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
606  _mm_prefetch(
607  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
608  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
609  // skip unnecessary prefetch of (&ip_next_T0[8])
610  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
611  _mm_prefetch(
612  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
613  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
614  // skip unnecessary prefetch of (&ip_next_T0[24])
615  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
616  _mm_prefetch(
617  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
618  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
619  // skip unnecessary prefetch of (&ip_next_T0[40])
620  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
621  _mm_prefetch(
622  reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
623  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
624  // skip unnecessary prefetch of (&ip_next_T0[56])
625  }
626  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
627  _mm256_storeu_ps(&op[0], vop0);
628  _mm256_storeu_ps(&op[8], vop8);
629  _mm256_storeu_ps(&op[16], vop16);
630  _mm256_storeu_ps(&op[24], vop24);
631  _mm256_storeu_ps(&op[32], vop32);
632  _mm256_storeu_ps(&op[40], vop40);
633  _mm256_storeu_ps(&op[48], vop48);
634  _mm256_storeu_ps(&op[56], vop56);
635  } else {
636  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
637  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
638  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
639  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
640  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
641  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
642  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
643  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
644  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
645  }
646  }
647  } else if (block_size == 32) {
648  // unrolling 4 times
649  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
650  float* op = &out[rangeIndex * block_size];
651  __m256 vop0 = _mm256_setzero_ps();
652  __m256 vop8 = _mm256_setzero_ps();
653  __m256 vop16 = _mm256_setzero_ps();
654  __m256 vop24 = _mm256_setzero_ps();
655  if (dataInd + lengths[rangeIndex] > index_size) {
656  return false;
657  }
658  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
659  ++dataInd) {
660  const int64_t idx = indices[dataInd];
661  if (idx < 0 || idx >= data_size) {
662  return false;
663  }
664  float wgt = 1.f;
665  if (weights) {
666  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
667  }
668  __m256 vwgt = _mm256_set1_ps(wgt);
669  const float* ip = &input[idx * fused_block_size];
670  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
671  ? (dataInd + prefdist_T0)
672  : dataInd;
673  const int64_t idx_pref_T0 = indices[next_T0];
674  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
675  return false;
676  }
677  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
678  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
679  _mm_prefetch(
680  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
681  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
682  // skip unnecessary prefetch of (&ip_next_T0[8])
683  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
684  _mm_prefetch(
685  reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
686  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
687  // skip unnecessary prefetch of (&ip_next_T0[24])
688  }
689  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
690  _mm256_storeu_ps(&op[0], vop0);
691  _mm256_storeu_ps(&op[8], vop8);
692  _mm256_storeu_ps(&op[16], vop16);
693  _mm256_storeu_ps(&op[24], vop24);
694  } else {
695  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
696  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
697  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
698  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
699  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
700  }
701  }
702  } else if (block_size == 16) {
703  // unrolling 2 times
704  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
705  float* op = &out[rangeIndex * block_size];
706  __m256 vop0 = _mm256_setzero_ps();
707  __m256 vop8 = _mm256_setzero_ps();
708  if (dataInd + lengths[rangeIndex] > index_size) {
709  return false;
710  }
711  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
712  ++dataInd) {
713  const int64_t idx = indices[dataInd];
714  if (idx < 0 || idx >= data_size) {
715  return false;
716  }
717  float wgt = 1.f;
718  if (weights) {
719  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
720  }
721  __m256 vwgt = _mm256_set1_ps(wgt);
722  const float* ip = &input[idx * fused_block_size];
723  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
724  ? (dataInd + prefdist_T0)
725  : dataInd;
726  const int64_t idx_pref_T0 = indices[next_T0];
727  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
728  return false;
729  }
730  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
731  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
732  _mm_prefetch(
733  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
734  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
735  // skip unnecessary prefetch of (&ip_next_T0[8])
736  }
737  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
738  _mm256_storeu_ps(&op[0], vop0);
739  _mm256_storeu_ps(&op[8], vop8);
740  } else {
741  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
742  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
743  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
744  }
745  }
746  } else {
747  // generic code
748  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
749  float* op = &out[rangeIndex * block_size];
750  int64_t j = 0;
751  for (; j + 8 <= block_size; j += 8) {
752  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
753  }
754  for (; j < block_size; j++) {
755  op[j] = 0.0f;
756  }
757  if (dataInd + lengths[rangeIndex] > index_size) {
758  return false;
759  }
760  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
761  ++dataInd) {
762  const int64_t idx = indices[dataInd];
763  if (idx < 0 || idx >= data_size) {
764  return false;
765  }
766  float wgt = 1.f;
767  if (weights) {
768  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
769  }
770  __m256 vwgt = _mm256_set1_ps(wgt);
771  const float* ip = &input[idx * fused_block_size];
772  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
773  ? (dataInd + prefdist_T0)
774  : dataInd;
775  const int64_t idx_pref_T0 = indices[next_T0];
776  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
777  return false;
778  }
779  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
780  j = 0;
781  for (; j + 8 <= block_size; j += 8) {
782  _mm256_storeu_ps(
783  &op[j],
784  _mm256_fmadd_ps(
785  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
786  _mm_prefetch(
787  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
788  }
789  for (; j < block_size; j++) {
790  op[j] += wgt * ip[j];
791  }
792  }
793  if (normalize_by_lengths && lengths[rangeIndex]) {
794  float len_inv = 1.0f / lengths[rangeIndex];
795  __m256 vlen_inv = _mm256_set1_ps(len_inv);
796  j = 0;
797  for (; j + 8 <= block_size; j += 8) {
798  _mm256_storeu_ps(
799  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
800  }
801  for (; j < block_size; j++) {
802  op[j] = len_inv * op[j];
803  }
804  }
805  }
806  }
807  return dataInd == index_size;
808 }
809 bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_false__avx2_fma(
810  const int64_t block_size,
811  const int64_t output_size,
812  const int64_t index_size,
813  const int64_t data_size,
814  const float* input,
815  const int64_t* indices,
816  const int* lengths,
817  const float* weights,
818  bool normalize_by_lengths,
819  float* out) {
820  return Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<false>(
821  block_size,
822  output_size,
823  index_size,
824  data_size,
825  input,
826  indices,
827  lengths,
828  weights,
829  normalize_by_lengths,
830  out);
831 }
832 bool Fused8BitRowwiseEmbeddingLookup_int64_t_float_float_true__avx2_fma(
833  const int64_t block_size,
834  const int64_t output_size,
835  const int64_t index_size,
836  const int64_t data_size,
837  const float* input,
838  const int64_t* indices,
839  const int* lengths,
840  const float* weights,
841  bool normalize_by_lengths,
842  float* out) {
843  return Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma<true>(
844  block_size,
845  output_size,
846  index_size,
847  data_size,
848  input,
849  indices,
850  lengths,
851  weights,
852  normalize_by_lengths,
853  out);
854 }
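The int64_t variants above are identical to the int32_t ones apart from the index and loop-counter types. All of the unrolled paths also share the same inner step: one FMA per 8 floats of the current row, plus a software prefetch of the row that will be needed prefdist_T0 = 16 indices later, issued roughly once per 64-byte cache line (hence the "skip unnecessary prefetch" comments on every second 32-byte chunk). A self-contained sketch of that step, with an illustrative helper name:

#include <immintrin.h>

// Processes 16 floats of the current row: one FMA per __m256 accumulator,
// and a single prefetch of the corresponding chunk of the row that will be
// needed prefdist_T0 indices later.
static inline void fma_two_chunks_with_prefetch(
    __m256& acc0, __m256& acc1, __m256 vwgt,
    const float* ip, const float* ip_next) {
  acc0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + 0), acc0);
  _mm_prefetch(reinterpret_cast<const char*>(ip_next), _MM_HINT_T0);
  acc1 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + 8), acc1);
  // The generated code omits a second prefetch here: the next 32-byte chunk
  // usually falls within the 64-byte line already requested above.
}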
855 
856 template <bool IS_WEIGHT_POSITIONAL>
857 static bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma(
858  const int64_t block_size,
859  const int64_t output_size,
860  const int64_t index_size,
861  const int64_t data_size,
862  const at::Half* input,
863  const int* indices,
864  const int* lengths,
865  const float* weights,
866  bool normalize_by_lengths,
867  float* out) {
868  const int prefdist_T0 = 16;
869  const int fused_block_size = block_size + 4;
870  int dataInd = 0;
871  if (block_size == 128) {
872  // unrolling 16 times
873  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
874  float* op = &out[rangeIndex * block_size];
875  __m256 vop0 = _mm256_setzero_ps();
876  __m256 vop8 = _mm256_setzero_ps();
877  __m256 vop16 = _mm256_setzero_ps();
878  __m256 vop24 = _mm256_setzero_ps();
879  __m256 vop32 = _mm256_setzero_ps();
880  __m256 vop40 = _mm256_setzero_ps();
881  __m256 vop48 = _mm256_setzero_ps();
882  __m256 vop56 = _mm256_setzero_ps();
883  __m256 vop64 = _mm256_setzero_ps();
884  __m256 vop72 = _mm256_setzero_ps();
885  __m256 vop80 = _mm256_setzero_ps();
886  __m256 vop88 = _mm256_setzero_ps();
887  __m256 vop96 = _mm256_setzero_ps();
888  __m256 vop104 = _mm256_setzero_ps();
889  __m256 vop112 = _mm256_setzero_ps();
890  __m256 vop120 = _mm256_setzero_ps();
891  if (dataInd + lengths[rangeIndex] > index_size) {
892  return false;
893  }
894  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
895  ++dataInd) {
896  const int idx = indices[dataInd];
897  if (idx < 0 || idx >= data_size) {
898  return false;
899  }
900  float wgt = 1.f;
901  if (weights) {
902  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
903  }
904  __m256 vwgt = _mm256_set1_ps(wgt);
905  const at::Half* ip = &input[idx * fused_block_size];
906  const int next_T0 = (dataInd < index_size - prefdist_T0)
907  ? (dataInd + prefdist_T0)
908  : dataInd;
909  const int idx_pref_T0 = indices[next_T0];
910  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
911  return false;
912  }
913  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
914  vop0 = _mm256_fmadd_ps(
915  vwgt,
916  _mm256_cvtph_ps(
917  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
918  vop0);
919  _mm_prefetch(
920  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
921  vop8 = _mm256_fmadd_ps(
922  vwgt,
923  _mm256_cvtph_ps(
924  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
925  vop8);
926  // skip unnecessary prefetch of (&ip_next_T0[8])
927  vop16 = _mm256_fmadd_ps(
928  vwgt,
929  _mm256_cvtph_ps(
930  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
931  vop16);
932  // skip unnecessary prefetch of (&ip_next_T0[16])
933  vop24 = _mm256_fmadd_ps(
934  vwgt,
935  _mm256_cvtph_ps(
936  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
937  vop24);
938  // skip unnecessary prefetch of (&ip_next_T0[24])
939  vop32 = _mm256_fmadd_ps(
940  vwgt,
941  _mm256_cvtph_ps(
942  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
943  vop32);
944  _mm_prefetch(
945  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
946  vop40 = _mm256_fmadd_ps(
947  vwgt,
948  _mm256_cvtph_ps(
949  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
950  vop40);
951  // skip unnecessary prefetch of (&ip_next_T0[40])
952  vop48 = _mm256_fmadd_ps(
953  vwgt,
954  _mm256_cvtph_ps(
955  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
956  vop48);
957  // skip unnecessary prefetch of (&ip_next_T0[48])
958  vop56 = _mm256_fmadd_ps(
959  vwgt,
960  _mm256_cvtph_ps(
961  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
962  vop56);
963  // skip unnecessary prefetch of (&ip_next_T0[56])
964  vop64 = _mm256_fmadd_ps(
965  vwgt,
966  _mm256_cvtph_ps(
967  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
968  vop64);
969  _mm_prefetch(
970  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
971  vop72 = _mm256_fmadd_ps(
972  vwgt,
973  _mm256_cvtph_ps(
974  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
975  vop72);
976  // skip unnecessary prefetch of (&ip_next_T0[72])
977  vop80 = _mm256_fmadd_ps(
978  vwgt,
979  _mm256_cvtph_ps(
980  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
981  vop80);
982  // skip unnecessary prefetch of (&ip_next_T0[80])
983  vop88 = _mm256_fmadd_ps(
984  vwgt,
985  _mm256_cvtph_ps(
986  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
987  vop88);
988  // skip unnecessary prefetch of (&ip_next_T0[88])
989  vop96 = _mm256_fmadd_ps(
990  vwgt,
991  _mm256_cvtph_ps(
992  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
993  vop96);
994  _mm_prefetch(
995  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
996  vop104 = _mm256_fmadd_ps(
997  vwgt,
998  _mm256_cvtph_ps(
999  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1000  vop104);
1001  // skip unnecessary prefetch of (&ip_next_T0[104])
1002  vop112 = _mm256_fmadd_ps(
1003  vwgt,
1004  _mm256_cvtph_ps(
1005  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1006  vop112);
1007  // skip unnecessary prefetch of (&ip_next_T0[112])
1008  vop120 = _mm256_fmadd_ps(
1009  vwgt,
1010  _mm256_cvtph_ps(
1011  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1012  vop120);
1013  // skip unnecessary prefetch of (&ip_next_T0[120])
1014  }
1015  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1016  _mm256_storeu_ps(&op[0], vop0);
1017  _mm256_storeu_ps(&op[8], vop8);
1018  _mm256_storeu_ps(&op[16], vop16);
1019  _mm256_storeu_ps(&op[24], vop24);
1020  _mm256_storeu_ps(&op[32], vop32);
1021  _mm256_storeu_ps(&op[40], vop40);
1022  _mm256_storeu_ps(&op[48], vop48);
1023  _mm256_storeu_ps(&op[56], vop56);
1024  _mm256_storeu_ps(&op[64], vop64);
1025  _mm256_storeu_ps(&op[72], vop72);
1026  _mm256_storeu_ps(&op[80], vop80);
1027  _mm256_storeu_ps(&op[88], vop88);
1028  _mm256_storeu_ps(&op[96], vop96);
1029  _mm256_storeu_ps(&op[104], vop104);
1030  _mm256_storeu_ps(&op[112], vop112);
1031  _mm256_storeu_ps(&op[120], vop120);
1032  } else {
1033  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1034  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1035  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1036  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1037  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1038  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1039  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1040  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1041  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1042  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1043  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1044  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1045  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1046  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1047  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1048  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1049  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1050  }
1051  }
1052  } else if (block_size == 64) {
1053  // unrolling 8 times
1054  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1055  float* op = &out[rangeIndex * block_size];
1056  __m256 vop0 = _mm256_setzero_ps();
1057  __m256 vop8 = _mm256_setzero_ps();
1058  __m256 vop16 = _mm256_setzero_ps();
1059  __m256 vop24 = _mm256_setzero_ps();
1060  __m256 vop32 = _mm256_setzero_ps();
1061  __m256 vop40 = _mm256_setzero_ps();
1062  __m256 vop48 = _mm256_setzero_ps();
1063  __m256 vop56 = _mm256_setzero_ps();
1064  if (dataInd + lengths[rangeIndex] > index_size) {
1065  return false;
1066  }
1067  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1068  ++dataInd) {
1069  const int idx = indices[dataInd];
1070  if (idx < 0 || idx >= data_size) {
1071  return false;
1072  }
1073  float wgt = 1.f;
1074  if (weights) {
1075  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1076  }
1077  __m256 vwgt = _mm256_set1_ps(wgt);
1078  const at::Half* ip = &input[idx * fused_block_size];
1079  const int next_T0 = (dataInd < index_size - prefdist_T0)
1080  ? (dataInd + prefdist_T0)
1081  : dataInd;
1082  const int idx_pref_T0 = indices[next_T0];
1083  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1084  return false;
1085  }
1086  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1087  vop0 = _mm256_fmadd_ps(
1088  vwgt,
1089  _mm256_cvtph_ps(
1090  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1091  vop0);
1092  _mm_prefetch(
1093  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1094  vop8 = _mm256_fmadd_ps(
1095  vwgt,
1096  _mm256_cvtph_ps(
1097  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1098  vop8);
1099  // skip unnecessary prefetch of (&ip_next_T0[8])
1100  vop16 = _mm256_fmadd_ps(
1101  vwgt,
1102  _mm256_cvtph_ps(
1103  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1104  vop16);
1105  // skip unnecessary prefetch of (&ip_next_T0[16])
1106  vop24 = _mm256_fmadd_ps(
1107  vwgt,
1108  _mm256_cvtph_ps(
1109  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1110  vop24);
1111  // skip unnecessary prefetch of (&ip_next_T0[24])
1112  vop32 = _mm256_fmadd_ps(
1113  vwgt,
1114  _mm256_cvtph_ps(
1115  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1116  vop32);
1117  _mm_prefetch(
1118  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1119  vop40 = _mm256_fmadd_ps(
1120  vwgt,
1121  _mm256_cvtph_ps(
1122  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1123  vop40);
1124  // skip unnecessary prefetch of (&ip_next_T0[40])
1125  vop48 = _mm256_fmadd_ps(
1126  vwgt,
1127  _mm256_cvtph_ps(
1128  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1129  vop48);
1130  // skip unnecessary prefetch of (&ip_next_T0[48])
1131  vop56 = _mm256_fmadd_ps(
1132  vwgt,
1133  _mm256_cvtph_ps(
1134  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1135  vop56);
1136  // skip unnecessary prefetch of (&ip_next_T0[56])
1137  }
1138  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1139  _mm256_storeu_ps(&op[0], vop0);
1140  _mm256_storeu_ps(&op[8], vop8);
1141  _mm256_storeu_ps(&op[16], vop16);
1142  _mm256_storeu_ps(&op[24], vop24);
1143  _mm256_storeu_ps(&op[32], vop32);
1144  _mm256_storeu_ps(&op[40], vop40);
1145  _mm256_storeu_ps(&op[48], vop48);
1146  _mm256_storeu_ps(&op[56], vop56);
1147  } else {
1148  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1149  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1150  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1151  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1152  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1153  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1154  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1155  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1156  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1157  }
1158  }
1159  } else if (block_size == 32) {
1160  // unrolling 4 times
1161  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1162  float* op = &out[rangeIndex * block_size];
1163  __m256 vop0 = _mm256_setzero_ps();
1164  __m256 vop8 = _mm256_setzero_ps();
1165  __m256 vop16 = _mm256_setzero_ps();
1166  __m256 vop24 = _mm256_setzero_ps();
1167  if (dataInd + lengths[rangeIndex] > index_size) {
1168  return false;
1169  }
1170  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1171  ++dataInd) {
1172  const int idx = indices[dataInd];
1173  if (idx < 0 || idx >= data_size) {
1174  return false;
1175  }
1176  float wgt = 1.f;
1177  if (weights) {
1178  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1179  }
1180  __m256 vwgt = _mm256_set1_ps(wgt);
1181  const at::Half* ip = &input[idx * fused_block_size];
1182  const int next_T0 = (dataInd < index_size - prefdist_T0)
1183  ? (dataInd + prefdist_T0)
1184  : dataInd;
1185  const int idx_pref_T0 = indices[next_T0];
1186  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1187  return false;
1188  }
1189  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1190  vop0 = _mm256_fmadd_ps(
1191  vwgt,
1192  _mm256_cvtph_ps(
1193  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1194  vop0);
1195  _mm_prefetch(
1196  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1197  vop8 = _mm256_fmadd_ps(
1198  vwgt,
1199  _mm256_cvtph_ps(
1200  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1201  vop8);
1202  // skip unnecessary prefetch of (&ip_next_T0[8])
1203  vop16 = _mm256_fmadd_ps(
1204  vwgt,
1205  _mm256_cvtph_ps(
1206  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1207  vop16);
1208  // skip unnecessary prefetch of (&ip_next_T0[16])
1209  vop24 = _mm256_fmadd_ps(
1210  vwgt,
1211  _mm256_cvtph_ps(
1212  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1213  vop24);
1214  // skip unnecessary prefetch of (&ip_next_T0[24])
1215  }
1216  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1217  _mm256_storeu_ps(&op[0], vop0);
1218  _mm256_storeu_ps(&op[8], vop8);
1219  _mm256_storeu_ps(&op[16], vop16);
1220  _mm256_storeu_ps(&op[24], vop24);
1221  } else {
1222  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1223  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1224  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1225  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1226  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1227  }
1228  }
1229  } else if (block_size == 16) {
1230  // unrolling 2 times
1231  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1232  float* op = &out[rangeIndex * block_size];
1233  __m256 vop0 = _mm256_setzero_ps();
1234  __m256 vop8 = _mm256_setzero_ps();
1235  if (dataInd + lengths[rangeIndex] > index_size) {
1236  return false;
1237  }
1238  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1239  ++dataInd) {
1240  const int idx = indices[dataInd];
1241  if (idx < 0 || idx >= data_size) {
1242  return false;
1243  }
1244  float wgt = 1.f;
1245  if (weights) {
1246  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1247  }
1248  __m256 vwgt = _mm256_set1_ps(wgt);
1249  const at::Half* ip = &input[idx * fused_block_size];
1250  const int next_T0 = (dataInd < index_size - prefdist_T0)
1251  ? (dataInd + prefdist_T0)
1252  : dataInd;
1253  const int idx_pref_T0 = indices[next_T0];
1254  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1255  return false;
1256  }
1257  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1258  vop0 = _mm256_fmadd_ps(
1259  vwgt,
1260  _mm256_cvtph_ps(
1261  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1262  vop0);
1263  _mm_prefetch(
1264  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1265  vop8 = _mm256_fmadd_ps(
1266  vwgt,
1267  _mm256_cvtph_ps(
1268  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1269  vop8);
1270  // skip unnecessary prefetch of (&ip_next_T0[8])
1271  }
1272  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1273  _mm256_storeu_ps(&op[0], vop0);
1274  _mm256_storeu_ps(&op[8], vop8);
1275  } else {
1276  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1277  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1278  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1279  }
1280  }
1281  } else {
1282  // generic code
1283  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1284  float* op = &out[rangeIndex * block_size];
1285  int64_t j = 0;
1286  for (; j + 8 <= block_size; j += 8) {
1287  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1288  }
1289  for (; j < block_size; j++) {
1290  op[j] = 0.0f;
1291  }
1292  if (dataInd + lengths[rangeIndex] > index_size) {
1293  return false;
1294  }
1295  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1296  ++dataInd) {
1297  const int idx = indices[dataInd];
1298  if (idx < 0 || idx >= data_size) {
1299  return false;
1300  }
1301  float wgt = 1.f;
1302  if (weights) {
1303  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1304  }
1305  __m256 vwgt = _mm256_set1_ps(wgt);
1306  const at::Half* ip = &input[idx * fused_block_size];
1307  const int next_T0 = (dataInd < index_size - prefdist_T0)
1308  ? (dataInd + prefdist_T0)
1309  : dataInd;
1310  const int idx_pref_T0 = indices[next_T0];
1311  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1312  return false;
1313  }
1314  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1315  j = 0;
1316  for (; j + 8 <= block_size; j += 8) {
1317  _mm256_storeu_ps(
1318  &op[j],
1319  _mm256_fmadd_ps(
1320  vwgt,
1321  _mm256_cvtph_ps(_mm_loadu_si128(
1322  reinterpret_cast<const __m128i*>(&ip[j]))),
1323  _mm256_loadu_ps(&op[j])));
1324  _mm_prefetch(
1325  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1326  }
1327  alignas(64) at::Half vtmp1[8];
1328  for (; j < block_size; j++) {
1329  vtmp1[0] = ip[j];
1330  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1331  op[j] += wgt * ((float*)(&vtmp2))[0];
1332  }
1333  }
1334  if (normalize_by_lengths && lengths[rangeIndex]) {
1335  float len_inv = 1.0f / lengths[rangeIndex];
1336  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1337  j = 0;
1338  for (; j + 8 <= block_size; j += 8) {
1339  _mm256_storeu_ps(
1340  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1341  }
1342  for (; j < block_size; j++) {
1343  op[j] = len_inv * op[j];
1344  }
1345  }
1346  }
1347  }
1348  return dataInd == index_size;
1349 }
1350 bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_false__avx2_fma(
1351  const int64_t block_size,
1352  const int64_t output_size,
1353  const int64_t index_size,
1354  const int64_t data_size,
1355  const at::Half* input,
1356  const int* indices,
1357  const int* lengths,
1358  const float* weights,
1359  bool normalize_by_lengths,
1360  float* out) {
1361  return Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma<false>(
1362  block_size,
1363  output_size,
1364  index_size,
1365  data_size,
1366  input,
1367  indices,
1368  lengths,
1369  weights,
1370  normalize_by_lengths,
1371  out);
1372 }
1373 bool Fused8BitRowwiseEmbeddingLookup_int32_t_half_float_true__avx2_fma(
1374  const int64_t block_size,
1375  const int64_t output_size,
1376  const int64_t index_size,
1377  const int64_t data_size,
1378  const at::Half* input,
1379  const int* indices,
1380  const int* lengths,
1381  const float* weights,
1382  bool normalize_by_lengths,
1383  float* out) {
1384  return Fused8BitRowwiseEmbeddingLookup_int32_t_half_float__avx2_fma<true>(
1385  block_size,
1386  output_size,
1387  index_size,
1388  data_size,
1389  input,
1390  indices,
1391  lengths,
1392  weights,
1393  normalize_by_lengths,
1394  out);
1395 }
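The at::Half variants store block_size fp16 values per row followed by the same 8 reserved bytes (so fused_block_size = block_size + 4 here), and widen each group of 8 values to fp32 before the FMA. A sketch of that conversion idiom, with an illustrative helper name (requires the F16C instruction set):

#include <immintrin.h>
#include <c10/util/Half.h>

// Loads 8 consecutive fp16 values and converts them to 8 fp32 lanes,
// mirroring the _mm256_cvtph_ps(_mm_loadu_si128(...)) pattern used above.
static inline __m256 load8_half_as_fp32(const at::Half* p) {
  return _mm256_cvtph_ps(
      _mm_loadu_si128(reinterpret_cast<const __m128i*>(p)));
}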
1396 
1397 template <bool IS_WEIGHT_POSITIONAL>
1398 static bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma(
1399  const int64_t block_size,
1400  const int64_t output_size,
1401  const int64_t index_size,
1402  const int64_t data_size,
1403  const at::Half* input,
1404  const int64_t* indices,
1405  const int* lengths,
1406  const float* weights,
1407  bool normalize_by_lengths,
1408  float* out) {
1409  const int64_t prefdist_T0 = 16;
1410  const int64_t fused_block_size = block_size + 4;
1411  int64_t dataInd = 0;
1412  if (block_size == 128) {
1413  // unrolling 16 times
1414  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1415  float* op = &out[rangeIndex * block_size];
1416  __m256 vop0 = _mm256_setzero_ps();
1417  __m256 vop8 = _mm256_setzero_ps();
1418  __m256 vop16 = _mm256_setzero_ps();
1419  __m256 vop24 = _mm256_setzero_ps();
1420  __m256 vop32 = _mm256_setzero_ps();
1421  __m256 vop40 = _mm256_setzero_ps();
1422  __m256 vop48 = _mm256_setzero_ps();
1423  __m256 vop56 = _mm256_setzero_ps();
1424  __m256 vop64 = _mm256_setzero_ps();
1425  __m256 vop72 = _mm256_setzero_ps();
1426  __m256 vop80 = _mm256_setzero_ps();
1427  __m256 vop88 = _mm256_setzero_ps();
1428  __m256 vop96 = _mm256_setzero_ps();
1429  __m256 vop104 = _mm256_setzero_ps();
1430  __m256 vop112 = _mm256_setzero_ps();
1431  __m256 vop120 = _mm256_setzero_ps();
1432  if (dataInd + lengths[rangeIndex] > index_size) {
1433  return false;
1434  }
1435  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1436  ++dataInd) {
1437  const int64_t idx = indices[dataInd];
1438  if (idx < 0 || idx >= data_size) {
1439  return false;
1440  }
1441  float wgt = 1.f;
1442  if (weights) {
1443  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1444  }
1445  __m256 vwgt = _mm256_set1_ps(wgt);
1446  const at::Half* ip = &input[idx * fused_block_size];
1447  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1448  ? (dataInd + prefdist_T0)
1449  : dataInd;
1450  const int64_t idx_pref_T0 = indices[next_T0];
1451  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1452  return false;
1453  }
1454  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1455  vop0 = _mm256_fmadd_ps(
1456  vwgt,
1457  _mm256_cvtph_ps(
1458  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1459  vop0);
1460  _mm_prefetch(
1461  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1462  vop8 = _mm256_fmadd_ps(
1463  vwgt,
1464  _mm256_cvtph_ps(
1465  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1466  vop8);
1467  // skip unnecessary prefetch of (&ip_next_T0[8])
1468  vop16 = _mm256_fmadd_ps(
1469  vwgt,
1470  _mm256_cvtph_ps(
1471  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1472  vop16);
1473  // skip unnecessary prefetch of (&ip_next_T0[16])
1474  vop24 = _mm256_fmadd_ps(
1475  vwgt,
1476  _mm256_cvtph_ps(
1477  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1478  vop24);
1479  // skip unnecessary prefetch of (&ip_next_T0[24])
1480  vop32 = _mm256_fmadd_ps(
1481  vwgt,
1482  _mm256_cvtph_ps(
1483  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1484  vop32);
1485  _mm_prefetch(
1486  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1487  vop40 = _mm256_fmadd_ps(
1488  vwgt,
1489  _mm256_cvtph_ps(
1490  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1491  vop40);
1492  // skip unnecessary prefetch of (&ip_next_T0[40])
1493  vop48 = _mm256_fmadd_ps(
1494  vwgt,
1495  _mm256_cvtph_ps(
1496  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1497  vop48);
1498  // skip unnecessary prefetch of (&ip_next_T0[48])
1499  vop56 = _mm256_fmadd_ps(
1500  vwgt,
1501  _mm256_cvtph_ps(
1502  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1503  vop56);
1504  // skip unnecessary prefetch of (&ip_next_T0[56])
1505  vop64 = _mm256_fmadd_ps(
1506  vwgt,
1507  _mm256_cvtph_ps(
1508  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1509  vop64);
1510  _mm_prefetch(
1511  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1512  vop72 = _mm256_fmadd_ps(
1513  vwgt,
1514  _mm256_cvtph_ps(
1515  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1516  vop72);
1517  // skip unnecessary prefetch of (&ip_next_T0[72])
1518  vop80 = _mm256_fmadd_ps(
1519  vwgt,
1520  _mm256_cvtph_ps(
1521  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1522  vop80);
1523  // skip unnecessary prefetch of (&ip_next_T0[80])
1524  vop88 = _mm256_fmadd_ps(
1525  vwgt,
1526  _mm256_cvtph_ps(
1527  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1528  vop88);
1529  // skip unnecessary prefetch of (&ip_next_T0[88])
1530  vop96 = _mm256_fmadd_ps(
1531  vwgt,
1532  _mm256_cvtph_ps(
1533  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1534  vop96);
1535  _mm_prefetch(
1536  reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1537  vop104 = _mm256_fmadd_ps(
1538  vwgt,
1539  _mm256_cvtph_ps(
1540  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1541  vop104);
1542  // skip unnecessary prefetch of (&ip_next_T0[104])
1543  vop112 = _mm256_fmadd_ps(
1544  vwgt,
1545  _mm256_cvtph_ps(
1546  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1547  vop112);
1548  // skip unnecessary prefetch of (&ip_next_T0[112])
1549  vop120 = _mm256_fmadd_ps(
1550  vwgt,
1551  _mm256_cvtph_ps(
1552  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1553  vop120);
1554  // skip unnecessary prefetch of (&ip_next_T0[120])
1555  }
1556  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1557  _mm256_storeu_ps(&op[0], vop0);
1558  _mm256_storeu_ps(&op[8], vop8);
1559  _mm256_storeu_ps(&op[16], vop16);
1560  _mm256_storeu_ps(&op[24], vop24);
1561  _mm256_storeu_ps(&op[32], vop32);
1562  _mm256_storeu_ps(&op[40], vop40);
1563  _mm256_storeu_ps(&op[48], vop48);
1564  _mm256_storeu_ps(&op[56], vop56);
1565  _mm256_storeu_ps(&op[64], vop64);
1566  _mm256_storeu_ps(&op[72], vop72);
1567  _mm256_storeu_ps(&op[80], vop80);
1568  _mm256_storeu_ps(&op[88], vop88);
1569  _mm256_storeu_ps(&op[96], vop96);
1570  _mm256_storeu_ps(&op[104], vop104);
1571  _mm256_storeu_ps(&op[112], vop112);
1572  _mm256_storeu_ps(&op[120], vop120);
1573  } else {
1574  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1575  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1576  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1577  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1578  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1579  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1580  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1581  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1582  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1583  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1584  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1585  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1586  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1587  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1588  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1589  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1590  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1591  }
1592  }
1593  } else if (block_size == 64) {
1594  // unrolling 8 times
1595  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1596  float* op = &out[rangeIndex * block_size];
1597  __m256 vop0 = _mm256_setzero_ps();
1598  __m256 vop8 = _mm256_setzero_ps();
1599  __m256 vop16 = _mm256_setzero_ps();
1600  __m256 vop24 = _mm256_setzero_ps();
1601  __m256 vop32 = _mm256_setzero_ps();
1602  __m256 vop40 = _mm256_setzero_ps();
1603  __m256 vop48 = _mm256_setzero_ps();
1604  __m256 vop56 = _mm256_setzero_ps();
1605  if (dataInd + lengths[rangeIndex] > index_size) {
1606  return false;
1607  }
1608  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1609  ++dataInd) {
1610  const int64_t idx = indices[dataInd];
1611  if (idx < 0 || idx >= data_size) {
1612  return false;
1613  }
1614  float wgt = 1.f;
1615  if (weights) {
1616  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1617  }
1618  __m256 vwgt = _mm256_set1_ps(wgt);
1619  const at::Half* ip = &input[idx * fused_block_size];
1620  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1621  ? (dataInd + prefdist_T0)
1622  : dataInd;
1623  const int64_t idx_pref_T0 = indices[next_T0];
1624  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1625  return false;
1626  }
1627  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1628  vop0 = _mm256_fmadd_ps(
1629  vwgt,
1630  _mm256_cvtph_ps(
1631  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1632  vop0);
1633  _mm_prefetch(
1634  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1635  vop8 = _mm256_fmadd_ps(
1636  vwgt,
1637  _mm256_cvtph_ps(
1638  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1639  vop8);
1640  // skip unnecessary prefetch of (&ip_next_T0[8])
1641  vop16 = _mm256_fmadd_ps(
1642  vwgt,
1643  _mm256_cvtph_ps(
1644  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1645  vop16);
1646  // skip unnecessary prefetch of (&ip_next_T0[16])
1647  vop24 = _mm256_fmadd_ps(
1648  vwgt,
1649  _mm256_cvtph_ps(
1650  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1651  vop24);
1652  // skip unnecessary prefetch of (&ip_next_T0[24])
1653  vop32 = _mm256_fmadd_ps(
1654  vwgt,
1655  _mm256_cvtph_ps(
1656  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1657  vop32);
1658  _mm_prefetch(
1659  reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1660  vop40 = _mm256_fmadd_ps(
1661  vwgt,
1662  _mm256_cvtph_ps(
1663  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1664  vop40);
1665  // skip unnecessary prefetch of (&ip_next_T0[40])
1666  vop48 = _mm256_fmadd_ps(
1667  vwgt,
1668  _mm256_cvtph_ps(
1669  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1670  vop48);
1671  // skip unnecessary prefetch of (&ip_next_T0[48])
1672  vop56 = _mm256_fmadd_ps(
1673  vwgt,
1674  _mm256_cvtph_ps(
1675  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1676  vop56);
1677  // skip unnecessary prefetch of (&ip_next_T0[56])
1678  }
1679  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1680  _mm256_storeu_ps(&op[0], vop0);
1681  _mm256_storeu_ps(&op[8], vop8);
1682  _mm256_storeu_ps(&op[16], vop16);
1683  _mm256_storeu_ps(&op[24], vop24);
1684  _mm256_storeu_ps(&op[32], vop32);
1685  _mm256_storeu_ps(&op[40], vop40);
1686  _mm256_storeu_ps(&op[48], vop48);
1687  _mm256_storeu_ps(&op[56], vop56);
1688  } else {
1689  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1690  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1691  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1692  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1693  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1694  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1695  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1696  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1697  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1698  }
1699  }
1700  } else if (block_size == 32) {
1701  // unrolling 4 times
1702  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1703  float* op = &out[rangeIndex * block_size];
1704  __m256 vop0 = _mm256_setzero_ps();
1705  __m256 vop8 = _mm256_setzero_ps();
1706  __m256 vop16 = _mm256_setzero_ps();
1707  __m256 vop24 = _mm256_setzero_ps();
1708  if (dataInd + lengths[rangeIndex] > index_size) {
1709  return false;
1710  }
1711  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1712  ++dataInd) {
1713  const int64_t idx = indices[dataInd];
1714  if (idx < 0 || idx >= data_size) {
1715  return false;
1716  }
1717  float wgt = 1.f;
1718  if (weights) {
1719  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1720  }
1721  __m256 vwgt = _mm256_set1_ps(wgt);
1722  const at::Half* ip = &input[idx * fused_block_size];
1723  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1724  ? (dataInd + prefdist_T0)
1725  : dataInd;
1726  const int64_t idx_pref_T0 = indices[next_T0];
1727  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1728  return false;
1729  }
1730  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1731  vop0 = _mm256_fmadd_ps(
1732  vwgt,
1733  _mm256_cvtph_ps(
1734  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1735  vop0);
1736  _mm_prefetch(
1737  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1738  vop8 = _mm256_fmadd_ps(
1739  vwgt,
1740  _mm256_cvtph_ps(
1741  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1742  vop8);
1743  // skip unnecessary prefetch of (&ip_next_T0[8])
1744  vop16 = _mm256_fmadd_ps(
1745  vwgt,
1746  _mm256_cvtph_ps(
1747  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1748  vop16);
1749  // skip unnecessary prefetch of (&ip_next_T0[16])
1750  vop24 = _mm256_fmadd_ps(
1751  vwgt,
1752  _mm256_cvtph_ps(
1753  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1754  vop24);
1755  // skip unnecessary prefetch of (&ip_next_T0[24])
1756  }
1757  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1758  _mm256_storeu_ps(&op[0], vop0);
1759  _mm256_storeu_ps(&op[8], vop8);
1760  _mm256_storeu_ps(&op[16], vop16);
1761  _mm256_storeu_ps(&op[24], vop24);
1762  } else {
1763  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1764  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1765  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1766  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1767  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1768  }
1769  }
1770  } else if (block_size == 16) {
1771  // unrolling 2 times
1772  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1773  float* op = &out[rangeIndex * block_size];
1774  __m256 vop0 = _mm256_setzero_ps();
1775  __m256 vop8 = _mm256_setzero_ps();
1776  if (dataInd + lengths[rangeIndex] > index_size) {
1777  return false;
1778  }
1779  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1780  ++dataInd) {
1781  const int64_t idx = indices[dataInd];
1782  if (idx < 0 || idx >= data_size) {
1783  return false;
1784  }
1785  float wgt = 1.f;
1786  if (weights) {
1787  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1788  }
1789  __m256 vwgt = _mm256_set1_ps(wgt);
1790  const at::Half* ip = &input[idx * fused_block_size];
1791  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1792  ? (dataInd + prefdist_T0)
1793  : dataInd;
1794  const int64_t idx_pref_T0 = indices[next_T0];
1795  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1796  return false;
1797  }
1798  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1799  vop0 = _mm256_fmadd_ps(
1800  vwgt,
1801  _mm256_cvtph_ps(
1802  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1803  vop0);
1804  _mm_prefetch(
1805  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1806  vop8 = _mm256_fmadd_ps(
1807  vwgt,
1808  _mm256_cvtph_ps(
1809  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1810  vop8);
1811  // skip unnecessary prefetch of (&ip_next_T0[8])
1812  }
1813  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
1814  _mm256_storeu_ps(&op[0], vop0);
1815  _mm256_storeu_ps(&op[8], vop8);
1816  } else {
1817  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1818  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1819  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1820  }
1821  }
1822  } else {
1823  // generic code
1824  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1825  float* op = &out[rangeIndex * block_size];
1826  int64_t j = 0;
1827  for (; j + 8 <= block_size; j += 8) {
1828  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1829  }
1830  for (; j < block_size; j++) {
1831  op[j] = 0.0f;
1832  }
1833  if (dataInd + lengths[rangeIndex] > index_size) {
1834  return false;
1835  }
1836  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1837  ++dataInd) {
1838  const int64_t idx = indices[dataInd];
1839  if (idx < 0 || idx >= data_size) {
1840  return false;
1841  }
1842  float wgt = 1.f;
1843  if (weights) {
1844  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1845  }
1846  __m256 vwgt = _mm256_set1_ps(wgt);
1847  const at::Half* ip = &input[idx * fused_block_size];
1848  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1849  ? (dataInd + prefdist_T0)
1850  : dataInd;
1851  const int64_t idx_pref_T0 = indices[next_T0];
1852  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1853  return false;
1854  }
1855  const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1856  j = 0;
1857  for (; j + 8 <= block_size; j += 8) {
1858  _mm256_storeu_ps(
1859  &op[j],
1860  _mm256_fmadd_ps(
1861  vwgt,
1862  _mm256_cvtph_ps(_mm_loadu_si128(
1863  reinterpret_cast<const __m128i*>(&ip[j]))),
1864  _mm256_loadu_ps(&op[j])));
1865  _mm_prefetch(
1866  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1867  }
 1868  alignas(64) at::Half vtmp1[8]; // scratch for the scalar fp16->fp32 tail; only lane 0 is read
1869  for (; j < block_size; j++) {
1870  vtmp1[0] = ip[j];
1871  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1872  op[j] += wgt * ((float*)(&vtmp2))[0];
1873  }
1874  }
1875  if (normalize_by_lengths && lengths[rangeIndex]) {
1876  float len_inv = 1.0f / lengths[rangeIndex];
1877  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1878  j = 0;
1879  for (; j + 8 <= block_size; j += 8) {
1880  _mm256_storeu_ps(
1881  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1882  }
1883  for (; j < block_size; j++) {
1884  op[j] = len_inv * op[j];
1885  }
1886  }
1887  }
1888  }
1889  return dataInd == index_size;
1890 }
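// Editorial note (not in the original source): the return convention used by
// every specialization in this file. The kernel bails out with `false` as soon
// as a segment would read past `index_size`, or as soon as any index --
// including the look-ahead index used only for prefetching -- falls outside
// [0, data_size). It returns `true` only when exactly `index_size` indices
// were consumed (`return dataInd == index_size;` above), so `false` signals
// malformed `lengths`/`indices` input rather than a usable partial result.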
1891 bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_false__avx2_fma(
1892  const int64_t block_size,
1893  const int64_t output_size,
1894  const int64_t index_size,
1895  const int64_t data_size,
1896  const at::Half* input,
1897  const int64_t* indices,
1898  const int* lengths,
1899  const float* weights,
1900  bool normalize_by_lengths,
1901  float* out) {
1902  return Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma<false>(
1903  block_size,
1904  output_size,
1905  index_size,
1906  data_size,
1907  input,
1908  indices,
1909  lengths,
1910  weights,
1911  normalize_by_lengths,
1912  out);
1913 }
1914 bool Fused8BitRowwiseEmbeddingLookup_int64_t_half_float_true__avx2_fma(
1915  const int64_t block_size,
1916  const int64_t output_size,
1917  const int64_t index_size,
1918  const int64_t data_size,
1919  const at::Half* input,
1920  const int64_t* indices,
1921  const int* lengths,
1922  const float* weights,
1923  bool normalize_by_lengths,
1924  float* out) {
1925  return Fused8BitRowwiseEmbeddingLookup_int64_t_half_float__avx2_fma<true>(
1926  block_size,
1927  output_size,
1928  index_size,
1929  data_size,
1930  input,
1931  indices,
1932  lengths,
1933  weights,
1934  normalize_by_lengths,
1935  out);
1936 }
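// Editorial sketch (not in the original source): a plain scalar reference for
// the at::Half specialization above, restricted to the
// IS_WEIGHT_POSITIONAL == false case and with the bounds checks and prefetch
// logic omitted. The function name is ours. Note the row stride of
// block_size + 4 halfs: each row reserves 8 trailing bytes -- presumably the
// slot that carries the float scale/bias in the 8-bit rowwise layout -- and
// the half path simply skips them.
inline void Fused8BitRowwiseEmbeddingLookup_half_scalar_sketch(
    const int64_t block_size,
    const int64_t output_size,
    const at::Half* input,
    const int64_t* indices,
    const int* lengths,
    const float* weights, // may be nullptr
    const bool normalize_by_lengths,
    float* out) {
  const int64_t fused_block_size = block_size + 4; // same row stride as above
  int64_t dataInd = 0;
  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
    float* op = &out[rangeIndex * block_size];
    for (int64_t j = 0; j < block_size; ++j) {
      op[j] = 0.0f;
    }
    for (int k = 0; k < lengths[rangeIndex]; ++k, ++dataInd) {
      const float wgt = weights ? weights[dataInd] : 1.f;
      const at::Half* ip = &input[indices[dataInd] * fused_block_size];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] += wgt * static_cast<float>(ip[j]); // fp16 -> fp32, weighted sum
      }
    }
    if (normalize_by_lengths && lengths[rangeIndex] != 0) {
      const float len_inv = 1.0f / lengths[rangeIndex];
      for (int64_t j = 0; j < block_size; ++j) {
        op[j] *= len_inv;
      }
    }
  }
}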
1937 
1938 template <bool IS_WEIGHT_POSITIONAL>
1939 static bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
1940  const int64_t block_size,
1941  const int64_t output_size,
1942  const int64_t index_size,
1943  const int64_t data_size,
1944  const uint8_t* input,
1945  const int* indices,
1946  const int* lengths,
1947  const float* weights,
1948  bool normalize_by_lengths,
1949  float* out) {
1950  const int prefdist_T0 = 16;
1951  const int fused_block_size = block_size + 8;
1952  int dataInd = 0;
1953  if (block_size == 128) {
1954  // unrolling 16 times
1955  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1956  float* op = &out[rangeIndex * block_size];
1957  __m256 vop0 = _mm256_setzero_ps();
1958  __m256 vop8 = _mm256_setzero_ps();
1959  __m256 vop16 = _mm256_setzero_ps();
1960  __m256 vop24 = _mm256_setzero_ps();
1961  __m256 vop32 = _mm256_setzero_ps();
1962  __m256 vop40 = _mm256_setzero_ps();
1963  __m256 vop48 = _mm256_setzero_ps();
1964  __m256 vop56 = _mm256_setzero_ps();
1965  __m256 vop64 = _mm256_setzero_ps();
1966  __m256 vop72 = _mm256_setzero_ps();
1967  __m256 vop80 = _mm256_setzero_ps();
1968  __m256 vop88 = _mm256_setzero_ps();
1969  __m256 vop96 = _mm256_setzero_ps();
1970  __m256 vop104 = _mm256_setzero_ps();
1971  __m256 vop112 = _mm256_setzero_ps();
1972  __m256 vop120 = _mm256_setzero_ps();
1973  if (dataInd + lengths[rangeIndex] > index_size) {
1974  return false;
1975  }
1976  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
1977  ++dataInd) {
1978  const int idx = indices[dataInd];
1979  if (idx < 0 || idx >= data_size) {
1980  return false;
1981  }
1982  float wgt = 1.f;
1983  float bio;
1984  if (weights) {
1985  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1986  }
 1987  const float* scale_bias = reinterpret_cast<const float*>(
 1988  &input[idx * fused_block_size + block_size]); // per-row scale/bias stored after the quantized values
 1989  bio = wgt * scale_bias[1]; // weight-scaled bias
 1990  wgt = wgt * scale_bias[0]; // weight-scaled scale
1991  __m256 vbio = _mm256_set1_ps(bio);
1992  __m256 vwgt = _mm256_set1_ps(wgt);
1993  const uint8_t* ip = &input[idx * fused_block_size];
1994  const int next_T0 = (dataInd < index_size - prefdist_T0)
1995  ? (dataInd + prefdist_T0)
1996  : dataInd;
1997  const int idx_pref_T0 = indices[next_T0];
1998  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1999  return false;
2000  }
2001  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2002  vop0 = _mm256_fmadd_ps(
2003  vwgt,
2004  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2005  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2006  _mm256_add_ps(vop0, vbio));
2007  _mm_prefetch(
2008  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2009  vop8 = _mm256_fmadd_ps(
2010  vwgt,
2011  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2012  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2013  _mm256_add_ps(vop8, vbio));
2014  // skip unnecessary prefetch of (&ip_next_T0[8])
2015  vop16 = _mm256_fmadd_ps(
2016  vwgt,
2017  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2018  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2019  _mm256_add_ps(vop16, vbio));
2020  // skip unnecessary prefetch of (&ip_next_T0[16])
2021  vop24 = _mm256_fmadd_ps(
2022  vwgt,
2023  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2024  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2025  _mm256_add_ps(vop24, vbio));
2026  // skip unnecessary prefetch of (&ip_next_T0[24])
2027  vop32 = _mm256_fmadd_ps(
2028  vwgt,
2029  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2030  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2031  _mm256_add_ps(vop32, vbio));
2032  // skip unnecessary prefetch of (&ip_next_T0[32])
2033  vop40 = _mm256_fmadd_ps(
2034  vwgt,
2035  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2036  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2037  _mm256_add_ps(vop40, vbio));
2038  // skip unnecessary prefetch of (&ip_next_T0[40])
2039  vop48 = _mm256_fmadd_ps(
2040  vwgt,
2041  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2042  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2043  _mm256_add_ps(vop48, vbio));
2044  // skip unnecessary prefetch of (&ip_next_T0[48])
2045  vop56 = _mm256_fmadd_ps(
2046  vwgt,
2047  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2048  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2049  _mm256_add_ps(vop56, vbio));
2050  // skip unnecessary prefetch of (&ip_next_T0[56])
2051  vop64 = _mm256_fmadd_ps(
2052  vwgt,
2053  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2054  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2055  _mm256_add_ps(vop64, vbio));
2056  _mm_prefetch(
2057  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2058  vop72 = _mm256_fmadd_ps(
2059  vwgt,
2060  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2061  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2062  _mm256_add_ps(vop72, vbio));
2063  // skip unnecessary prefetch of (&ip_next_T0[72])
2064  vop80 = _mm256_fmadd_ps(
2065  vwgt,
2066  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2067  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2068  _mm256_add_ps(vop80, vbio));
2069  // skip unnecessary prefetch of (&ip_next_T0[80])
2070  vop88 = _mm256_fmadd_ps(
2071  vwgt,
2072  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2073  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2074  _mm256_add_ps(vop88, vbio));
2075  // skip unnecessary prefetch of (&ip_next_T0[88])
2076  vop96 = _mm256_fmadd_ps(
2077  vwgt,
2078  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2079  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2080  _mm256_add_ps(vop96, vbio));
2081  // skip unnecessary prefetch of (&ip_next_T0[96])
2082  vop104 = _mm256_fmadd_ps(
2083  vwgt,
2084  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2085  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2086  _mm256_add_ps(vop104, vbio));
2087  // skip unnecessary prefetch of (&ip_next_T0[104])
2088  vop112 = _mm256_fmadd_ps(
2089  vwgt,
2090  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2091  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2092  _mm256_add_ps(vop112, vbio));
2093  // skip unnecessary prefetch of (&ip_next_T0[112])
2094  vop120 = _mm256_fmadd_ps(
2095  vwgt,
2096  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2097  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2098  _mm256_add_ps(vop120, vbio));
2099  // skip unnecessary prefetch of (&ip_next_T0[120])
2100  }
2101  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2102  _mm256_storeu_ps(&op[0], vop0);
2103  _mm256_storeu_ps(&op[8], vop8);
2104  _mm256_storeu_ps(&op[16], vop16);
2105  _mm256_storeu_ps(&op[24], vop24);
2106  _mm256_storeu_ps(&op[32], vop32);
2107  _mm256_storeu_ps(&op[40], vop40);
2108  _mm256_storeu_ps(&op[48], vop48);
2109  _mm256_storeu_ps(&op[56], vop56);
2110  _mm256_storeu_ps(&op[64], vop64);
2111  _mm256_storeu_ps(&op[72], vop72);
2112  _mm256_storeu_ps(&op[80], vop80);
2113  _mm256_storeu_ps(&op[88], vop88);
2114  _mm256_storeu_ps(&op[96], vop96);
2115  _mm256_storeu_ps(&op[104], vop104);
2116  _mm256_storeu_ps(&op[112], vop112);
2117  _mm256_storeu_ps(&op[120], vop120);
2118  } else {
2119  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2120  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2121  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2122  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2123  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2124  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2125  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2126  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2127  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2128  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2129  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2130  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2131  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2132  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2133  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2134  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2135  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2136  }
2137  }
2138  } else if (block_size == 64) {
2139  // unrolling 8 times
2140  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2141  float* op = &out[rangeIndex * block_size];
2142  __m256 vop0 = _mm256_setzero_ps();
2143  __m256 vop8 = _mm256_setzero_ps();
2144  __m256 vop16 = _mm256_setzero_ps();
2145  __m256 vop24 = _mm256_setzero_ps();
2146  __m256 vop32 = _mm256_setzero_ps();
2147  __m256 vop40 = _mm256_setzero_ps();
2148  __m256 vop48 = _mm256_setzero_ps();
2149  __m256 vop56 = _mm256_setzero_ps();
2150  if (dataInd + lengths[rangeIndex] > index_size) {
2151  return false;
2152  }
2153  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2154  ++dataInd) {
2155  const int idx = indices[dataInd];
2156  if (idx < 0 || idx >= data_size) {
2157  return false;
2158  }
2159  float wgt = 1.f;
2160  float bio;
2161  if (weights) {
2162  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2163  }
2164  const float* scale_bias = reinterpret_cast<const float*>(
2165  &input[idx * fused_block_size + block_size]);
2166  bio = wgt * scale_bias[1];
2167  wgt = wgt * scale_bias[0];
2168  __m256 vbio = _mm256_set1_ps(bio);
2169  __m256 vwgt = _mm256_set1_ps(wgt);
2170  const uint8_t* ip = &input[idx * fused_block_size];
2171  const int next_T0 = (dataInd < index_size - prefdist_T0)
2172  ? (dataInd + prefdist_T0)
2173  : dataInd;
2174  const int idx_pref_T0 = indices[next_T0];
2175  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2176  return false;
2177  }
2178  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2179  vop0 = _mm256_fmadd_ps(
2180  vwgt,
2181  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2182  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2183  _mm256_add_ps(vop0, vbio));
2184  _mm_prefetch(
2185  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2186  vop8 = _mm256_fmadd_ps(
2187  vwgt,
2188  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2189  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2190  _mm256_add_ps(vop8, vbio));
2191  // skip unnecessary prefetch of (&ip_next_T0[8])
2192  vop16 = _mm256_fmadd_ps(
2193  vwgt,
2194  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2195  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2196  _mm256_add_ps(vop16, vbio));
2197  // skip unnecessary prefetch of (&ip_next_T0[16])
2198  vop24 = _mm256_fmadd_ps(
2199  vwgt,
2200  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2201  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2202  _mm256_add_ps(vop24, vbio));
2203  // skip unnecessary prefetch of (&ip_next_T0[24])
2204  vop32 = _mm256_fmadd_ps(
2205  vwgt,
2206  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2207  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2208  _mm256_add_ps(vop32, vbio));
2209  // skip unnecessary prefetch of (&ip_next_T0[32])
2210  vop40 = _mm256_fmadd_ps(
2211  vwgt,
2212  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2213  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2214  _mm256_add_ps(vop40, vbio));
2215  // skip unnecessary prefetch of (&ip_next_T0[40])
2216  vop48 = _mm256_fmadd_ps(
2217  vwgt,
2218  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2219  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2220  _mm256_add_ps(vop48, vbio));
2221  // skip unnecessary prefetch of (&ip_next_T0[48])
2222  vop56 = _mm256_fmadd_ps(
2223  vwgt,
2224  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2225  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2226  _mm256_add_ps(vop56, vbio));
2227  // skip unnecessary prefetch of (&ip_next_T0[56])
2228  }
2229  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2230  _mm256_storeu_ps(&op[0], vop0);
2231  _mm256_storeu_ps(&op[8], vop8);
2232  _mm256_storeu_ps(&op[16], vop16);
2233  _mm256_storeu_ps(&op[24], vop24);
2234  _mm256_storeu_ps(&op[32], vop32);
2235  _mm256_storeu_ps(&op[40], vop40);
2236  _mm256_storeu_ps(&op[48], vop48);
2237  _mm256_storeu_ps(&op[56], vop56);
2238  } else {
2239  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2240  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2241  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2242  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2243  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2244  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2245  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2246  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2247  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2248  }
2249  }
2250  } else if (block_size == 32) {
2251  // unrolling 4 times
2252  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2253  float* op = &out[rangeIndex * block_size];
2254  __m256 vop0 = _mm256_setzero_ps();
2255  __m256 vop8 = _mm256_setzero_ps();
2256  __m256 vop16 = _mm256_setzero_ps();
2257  __m256 vop24 = _mm256_setzero_ps();
2258  if (dataInd + lengths[rangeIndex] > index_size) {
2259  return false;
2260  }
2261  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2262  ++dataInd) {
2263  const int idx = indices[dataInd];
2264  if (idx < 0 || idx >= data_size) {
2265  return false;
2266  }
2267  float wgt = 1.f;
2268  float bio;
2269  if (weights) {
2270  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2271  }
2272  const float* scale_bias = reinterpret_cast<const float*>(
2273  &input[idx * fused_block_size + block_size]);
2274  bio = wgt * scale_bias[1];
2275  wgt = wgt * scale_bias[0];
2276  __m256 vbio = _mm256_set1_ps(bio);
2277  __m256 vwgt = _mm256_set1_ps(wgt);
2278  const uint8_t* ip = &input[idx * fused_block_size];
2279  const int next_T0 = (dataInd < index_size - prefdist_T0)
2280  ? (dataInd + prefdist_T0)
2281  : dataInd;
2282  const int idx_pref_T0 = indices[next_T0];
2283  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2284  return false;
2285  }
2286  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2287  vop0 = _mm256_fmadd_ps(
2288  vwgt,
2289  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2290  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2291  _mm256_add_ps(vop0, vbio));
2292  _mm_prefetch(
2293  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2294  vop8 = _mm256_fmadd_ps(
2295  vwgt,
2296  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2297  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2298  _mm256_add_ps(vop8, vbio));
2299  // skip unnecessary prefetch of (&ip_next_T0[8])
2300  vop16 = _mm256_fmadd_ps(
2301  vwgt,
2302  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2303  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2304  _mm256_add_ps(vop16, vbio));
2305  // skip unnecessary prefetch of (&ip_next_T0[16])
2306  vop24 = _mm256_fmadd_ps(
2307  vwgt,
2308  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2309  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2310  _mm256_add_ps(vop24, vbio));
2311  // skip unnecessary prefetch of (&ip_next_T0[24])
2312  }
2313  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2314  _mm256_storeu_ps(&op[0], vop0);
2315  _mm256_storeu_ps(&op[8], vop8);
2316  _mm256_storeu_ps(&op[16], vop16);
2317  _mm256_storeu_ps(&op[24], vop24);
2318  } else {
2319  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2320  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2321  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2322  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2323  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2324  }
2325  }
2326  } else if (block_size == 16) {
2327  // unrolling 2 times
2328  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2329  float* op = &out[rangeIndex * block_size];
2330  __m256 vop0 = _mm256_setzero_ps();
2331  __m256 vop8 = _mm256_setzero_ps();
2332  if (dataInd + lengths[rangeIndex] > index_size) {
2333  return false;
2334  }
2335  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2336  ++dataInd) {
2337  const int idx = indices[dataInd];
2338  if (idx < 0 || idx >= data_size) {
2339  return false;
2340  }
2341  float wgt = 1.f;
2342  float bio;
2343  if (weights) {
2344  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2345  }
2346  const float* scale_bias = reinterpret_cast<const float*>(
2347  &input[idx * fused_block_size + block_size]);
2348  bio = wgt * scale_bias[1];
2349  wgt = wgt * scale_bias[0];
2350  __m256 vbio = _mm256_set1_ps(bio);
2351  __m256 vwgt = _mm256_set1_ps(wgt);
2352  const uint8_t* ip = &input[idx * fused_block_size];
2353  const int next_T0 = (dataInd < index_size - prefdist_T0)
2354  ? (dataInd + prefdist_T0)
2355  : dataInd;
2356  const int idx_pref_T0 = indices[next_T0];
2357  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2358  return false;
2359  }
2360  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2361  vop0 = _mm256_fmadd_ps(
2362  vwgt,
2363  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2364  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2365  _mm256_add_ps(vop0, vbio));
2366  _mm_prefetch(
2367  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2368  vop8 = _mm256_fmadd_ps(
2369  vwgt,
2370  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2371  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2372  _mm256_add_ps(vop8, vbio));
2373  // skip unnecessary prefetch of (&ip_next_T0[8])
2374  }
2375  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2376  _mm256_storeu_ps(&op[0], vop0);
2377  _mm256_storeu_ps(&op[8], vop8);
2378  } else {
2379  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2380  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2381  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2382  }
2383  }
2384  } else {
2385  // generic code
2386  for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2387  float* op = &out[rangeIndex * block_size];
2388  int64_t j = 0;
2389  for (; j + 8 <= block_size; j += 8) {
2390  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2391  }
2392  for (; j < block_size; j++) {
2393  op[j] = 0.0f;
2394  }
2395  if (dataInd + lengths[rangeIndex] > index_size) {
2396  return false;
2397  }
2398  for (int start = dataInd; dataInd < start + lengths[rangeIndex];
2399  ++dataInd) {
2400  const int idx = indices[dataInd];
2401  if (idx < 0 || idx >= data_size) {
2402  return false;
2403  }
2404  float wgt = 1.f;
2405  float bio;
2406  if (weights) {
2407  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2408  }
2409  const float* scale_bias = reinterpret_cast<const float*>(
2410  &input[idx * fused_block_size + block_size]);
2411  bio = wgt * scale_bias[1];
2412  wgt = wgt * scale_bias[0];
2413  __m256 vbio = _mm256_set1_ps(bio);
2414  __m256 vwgt = _mm256_set1_ps(wgt);
2415  const uint8_t* ip = &input[idx * fused_block_size];
2416  const int next_T0 = (dataInd < index_size - prefdist_T0)
2417  ? (dataInd + prefdist_T0)
2418  : dataInd;
2419  const int idx_pref_T0 = indices[next_T0];
2420  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2421  return false;
2422  }
2423  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2424  j = 0;
2425  for (; j + 8 <= block_size; j += 8) {
2426  _mm256_storeu_ps(
2427  &op[j],
2428  _mm256_fmadd_ps(
2429  vwgt,
2430  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2431  reinterpret_cast<const __m128i*>(&ip[j])))),
2432  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2433  _mm_prefetch(
2434  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
2435  }
2436  for (; j < block_size; j++) {
2437  op[j] += wgt * ((float)ip[j]) + bio;
2438  }
2439  }
2440  if (normalize_by_lengths && lengths[rangeIndex]) {
2441  float len_inv = 1.0f / lengths[rangeIndex];
2442  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2443  j = 0;
2444  for (; j + 8 <= block_size; j += 8) {
2445  _mm256_storeu_ps(
2446  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2447  }
2448  for (; j < block_size; j++) {
2449  op[j] = len_inv * op[j];
2450  }
2451  }
2452  }
2453  }
2454  return dataInd == index_size;
2455 }
2456 bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_false__avx2_fma(
2457  const int64_t block_size,
2458  const int64_t output_size,
2459  const int64_t index_size,
2460  const int64_t data_size,
2461  const uint8_t* input,
2462  const int* indices,
2463  const int* lengths,
2464  const float* weights,
2465  bool normalize_by_lengths,
2466  float* out) {
2467  return Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<false>(
2468  block_size,
2469  output_size,
2470  index_size,
2471  data_size,
2472  input,
2473  indices,
2474  lengths,
2475  weights,
2476  normalize_by_lengths,
2477  out);
2478 }
2479 bool Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float_true__avx2_fma(
2480  const int64_t block_size,
2481  const int64_t output_size,
2482  const int64_t index_size,
2483  const int64_t data_size,
2484  const uint8_t* input,
2485  const int* indices,
2486  const int* lengths,
2487  const float* weights,
2488  bool normalize_by_lengths,
2489  float* out) {
2490  return Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma<true>(
2491  block_size,
2492  output_size,
2493  index_size,
2494  data_size,
2495  input,
2496  indices,
2497  lengths,
2498  weights,
2499  normalize_by_lengths,
2500  out);
2501 }
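// Editorial sketch (not in the original source): each 8-bit row above stores
// block_size uint8_t values followed by two floats, scale_bias[0] (scale) and
// scale_bias[1] (bias). Folding them into the per-index weight keeps the
// inner loop at one FMA per 8 values, because
//   w * (scale * q + bias) == (w * scale) * q + (w * bias),
// which is exactly what the vwgt/vbio pair implements. A one-element scalar
// illustration (the helper name is ours):
inline float fused8bit_dequant_accumulate_sketch(
    float acc, // running accumulator for one output element
    float w, // per-index weight (1.f when `weights` is nullptr)
    uint8_t q, // one quantized value from the row
    float scale, // scale_bias[0] of that row
    float bias) { // scale_bias[1] of that row
  const float wgt = w * scale; // matches `wgt = wgt * scale_bias[0]` above
  const float bio = w * bias; // matches `bio = wgt * scale_bias[1]` above
  // The AVX2 code computes fma(vwgt, q, acc + vbio), i.e. acc + w*(scale*q + bias).
  return wgt * static_cast<float>(q) + (acc + bio);
}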
2502 
2503 template <bool IS_WEIGHT_POSITIONAL>
2504 static bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
2505  const int64_t block_size,
2506  const int64_t output_size,
2507  const int64_t index_size,
2508  const int64_t data_size,
2509  const uint8_t* input,
2510  const int64_t* indices,
2511  const int* lengths,
2512  const float* weights,
2513  bool normalize_by_lengths,
2514  float* out) {
2515  const int64_t prefdist_T0 = 16;
2516  const int64_t fused_block_size = block_size + 8;
2517  int64_t dataInd = 0;
2518  if (block_size == 128) {
2519  // unrolling 16 times
2520  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2521  float* op = &out[rangeIndex * block_size];
2522  __m256 vop0 = _mm256_setzero_ps();
2523  __m256 vop8 = _mm256_setzero_ps();
2524  __m256 vop16 = _mm256_setzero_ps();
2525  __m256 vop24 = _mm256_setzero_ps();
2526  __m256 vop32 = _mm256_setzero_ps();
2527  __m256 vop40 = _mm256_setzero_ps();
2528  __m256 vop48 = _mm256_setzero_ps();
2529  __m256 vop56 = _mm256_setzero_ps();
2530  __m256 vop64 = _mm256_setzero_ps();
2531  __m256 vop72 = _mm256_setzero_ps();
2532  __m256 vop80 = _mm256_setzero_ps();
2533  __m256 vop88 = _mm256_setzero_ps();
2534  __m256 vop96 = _mm256_setzero_ps();
2535  __m256 vop104 = _mm256_setzero_ps();
2536  __m256 vop112 = _mm256_setzero_ps();
2537  __m256 vop120 = _mm256_setzero_ps();
2538  if (dataInd + lengths[rangeIndex] > index_size) {
2539  return false;
2540  }
2541  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2542  ++dataInd) {
2543  const int64_t idx = indices[dataInd];
2544  if (idx < 0 || idx >= data_size) {
2545  return false;
2546  }
2547  float wgt = 1.f;
2548  float bio;
2549  if (weights) {
2550  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2551  }
2552  const float* scale_bias = reinterpret_cast<const float*>(
2553  &input[idx * fused_block_size + block_size]);
2554  bio = wgt * scale_bias[1];
2555  wgt = wgt * scale_bias[0];
2556  __m256 vbio = _mm256_set1_ps(bio);
2557  __m256 vwgt = _mm256_set1_ps(wgt);
2558  const uint8_t* ip = &input[idx * fused_block_size];
2559  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2560  ? (dataInd + prefdist_T0)
2561  : dataInd;
2562  const int64_t idx_pref_T0 = indices[next_T0];
2563  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2564  return false;
2565  }
2566  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2567  vop0 = _mm256_fmadd_ps(
2568  vwgt,
2569  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2570  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2571  _mm256_add_ps(vop0, vbio));
2572  _mm_prefetch(
2573  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2574  vop8 = _mm256_fmadd_ps(
2575  vwgt,
2576  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2577  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2578  _mm256_add_ps(vop8, vbio));
2579  // skip unnecessary prefetch of (&ip_next_T0[8])
2580  vop16 = _mm256_fmadd_ps(
2581  vwgt,
2582  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2583  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2584  _mm256_add_ps(vop16, vbio));
2585  // skip unnecessary prefetch of (&ip_next_T0[16])
2586  vop24 = _mm256_fmadd_ps(
2587  vwgt,
2588  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2589  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2590  _mm256_add_ps(vop24, vbio));
2591  // skip unnecessary prefetch of (&ip_next_T0[24])
2592  vop32 = _mm256_fmadd_ps(
2593  vwgt,
2594  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2595  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2596  _mm256_add_ps(vop32, vbio));
2597  // skip unnecessary prefetch of (&ip_next_T0[32])
2598  vop40 = _mm256_fmadd_ps(
2599  vwgt,
2600  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2601  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2602  _mm256_add_ps(vop40, vbio));
2603  // skip unnecessary prefetch of (&ip_next_T0[40])
2604  vop48 = _mm256_fmadd_ps(
2605  vwgt,
2606  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2607  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2608  _mm256_add_ps(vop48, vbio));
2609  // skip unnecessary prefetch of (&ip_next_T0[48])
2610  vop56 = _mm256_fmadd_ps(
2611  vwgt,
2612  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2613  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2614  _mm256_add_ps(vop56, vbio));
2615  // skip unnecessary prefetch of (&ip_next_T0[56])
2616  vop64 = _mm256_fmadd_ps(
2617  vwgt,
2618  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2619  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2620  _mm256_add_ps(vop64, vbio));
2621  _mm_prefetch(
2622  reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2623  vop72 = _mm256_fmadd_ps(
2624  vwgt,
2625  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2626  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2627  _mm256_add_ps(vop72, vbio));
2628  // skip unnecessary prefetch of (&ip_next_T0[72])
2629  vop80 = _mm256_fmadd_ps(
2630  vwgt,
2631  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2632  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2633  _mm256_add_ps(vop80, vbio));
2634  // skip unnecessary prefetch of (&ip_next_T0[80])
2635  vop88 = _mm256_fmadd_ps(
2636  vwgt,
2637  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2638  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2639  _mm256_add_ps(vop88, vbio));
2640  // skip unnecessary prefetch of (&ip_next_T0[88])
2641  vop96 = _mm256_fmadd_ps(
2642  vwgt,
2643  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2644  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2645  _mm256_add_ps(vop96, vbio));
2646  // skip unnecessary prefetch of (&ip_next_T0[96])
2647  vop104 = _mm256_fmadd_ps(
2648  vwgt,
2649  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2650  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2651  _mm256_add_ps(vop104, vbio));
2652  // skip unnecessary prefetch of (&ip_next_T0[104])
2653  vop112 = _mm256_fmadd_ps(
2654  vwgt,
2655  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2656  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2657  _mm256_add_ps(vop112, vbio));
2658  // skip unnecessary prefetch of (&ip_next_T0[112])
2659  vop120 = _mm256_fmadd_ps(
2660  vwgt,
2661  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2662  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2663  _mm256_add_ps(vop120, vbio));
2664  // skip unnecessary prefetch of (&ip_next_T0[120])
2665  }
2666  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2667  _mm256_storeu_ps(&op[0], vop0);
2668  _mm256_storeu_ps(&op[8], vop8);
2669  _mm256_storeu_ps(&op[16], vop16);
2670  _mm256_storeu_ps(&op[24], vop24);
2671  _mm256_storeu_ps(&op[32], vop32);
2672  _mm256_storeu_ps(&op[40], vop40);
2673  _mm256_storeu_ps(&op[48], vop48);
2674  _mm256_storeu_ps(&op[56], vop56);
2675  _mm256_storeu_ps(&op[64], vop64);
2676  _mm256_storeu_ps(&op[72], vop72);
2677  _mm256_storeu_ps(&op[80], vop80);
2678  _mm256_storeu_ps(&op[88], vop88);
2679  _mm256_storeu_ps(&op[96], vop96);
2680  _mm256_storeu_ps(&op[104], vop104);
2681  _mm256_storeu_ps(&op[112], vop112);
2682  _mm256_storeu_ps(&op[120], vop120);
2683  } else {
2684  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2685  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2686  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2687  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2688  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2689  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2690  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2691  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2692  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2693  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2694  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2695  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2696  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2697  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2698  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2699  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2700  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2701  }
2702  }
2703  } else if (block_size == 64) {
2704  // unrolling 8 times
2705  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2706  float* op = &out[rangeIndex * block_size];
2707  __m256 vop0 = _mm256_setzero_ps();
2708  __m256 vop8 = _mm256_setzero_ps();
2709  __m256 vop16 = _mm256_setzero_ps();
2710  __m256 vop24 = _mm256_setzero_ps();
2711  __m256 vop32 = _mm256_setzero_ps();
2712  __m256 vop40 = _mm256_setzero_ps();
2713  __m256 vop48 = _mm256_setzero_ps();
2714  __m256 vop56 = _mm256_setzero_ps();
2715  if (dataInd + lengths[rangeIndex] > index_size) {
2716  return false;
2717  }
2718  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2719  ++dataInd) {
2720  const int64_t idx = indices[dataInd];
2721  if (idx < 0 || idx >= data_size) {
2722  return false;
2723  }
2724  float wgt = 1.f;
2725  float bio;
2726  if (weights) {
2727  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2728  }
2729  const float* scale_bias = reinterpret_cast<const float*>(
2730  &input[idx * fused_block_size + block_size]);
2731  bio = wgt * scale_bias[1];
2732  wgt = wgt * scale_bias[0];
2733  __m256 vbio = _mm256_set1_ps(bio);
2734  __m256 vwgt = _mm256_set1_ps(wgt);
2735  const uint8_t* ip = &input[idx * fused_block_size];
2736  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2737  ? (dataInd + prefdist_T0)
2738  : dataInd;
2739  const int64_t idx_pref_T0 = indices[next_T0];
2740  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2741  return false;
2742  }
2743  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2744  vop0 = _mm256_fmadd_ps(
2745  vwgt,
2746  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2747  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2748  _mm256_add_ps(vop0, vbio));
2749  _mm_prefetch(
2750  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2751  vop8 = _mm256_fmadd_ps(
2752  vwgt,
2753  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2754  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2755  _mm256_add_ps(vop8, vbio));
2756  // skip unnecessary prefetch of (&ip_next_T0[8])
2757  vop16 = _mm256_fmadd_ps(
2758  vwgt,
2759  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2760  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2761  _mm256_add_ps(vop16, vbio));
2762  // skip unnecessary prefetch of (&ip_next_T0[16])
2763  vop24 = _mm256_fmadd_ps(
2764  vwgt,
2765  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2766  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2767  _mm256_add_ps(vop24, vbio));
2768  // skip unnecessary prefetch of (&ip_next_T0[24])
2769  vop32 = _mm256_fmadd_ps(
2770  vwgt,
2771  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2772  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2773  _mm256_add_ps(vop32, vbio));
2774  // skip unnecessary prefetch of (&ip_next_T0[32])
2775  vop40 = _mm256_fmadd_ps(
2776  vwgt,
2777  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2778  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2779  _mm256_add_ps(vop40, vbio));
2780  // skip unnecessary prefetch of (&ip_next_T0[40])
2781  vop48 = _mm256_fmadd_ps(
2782  vwgt,
2783  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2784  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2785  _mm256_add_ps(vop48, vbio));
2786  // skip unnecessary prefetch of (&ip_next_T0[48])
2787  vop56 = _mm256_fmadd_ps(
2788  vwgt,
2789  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2790  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2791  _mm256_add_ps(vop56, vbio));
2792  // skip unnecessary prefetch of (&ip_next_T0[56])
2793  }
2794  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2795  _mm256_storeu_ps(&op[0], vop0);
2796  _mm256_storeu_ps(&op[8], vop8);
2797  _mm256_storeu_ps(&op[16], vop16);
2798  _mm256_storeu_ps(&op[24], vop24);
2799  _mm256_storeu_ps(&op[32], vop32);
2800  _mm256_storeu_ps(&op[40], vop40);
2801  _mm256_storeu_ps(&op[48], vop48);
2802  _mm256_storeu_ps(&op[56], vop56);
2803  } else {
2804  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2805  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2806  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2807  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2808  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2809  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2810  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2811  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2812  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2813  }
2814  }
2815  } else if (block_size == 32) {
2816  // unrolling 4 times
2817  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2818  float* op = &out[rangeIndex * block_size];
2819  __m256 vop0 = _mm256_setzero_ps();
2820  __m256 vop8 = _mm256_setzero_ps();
2821  __m256 vop16 = _mm256_setzero_ps();
2822  __m256 vop24 = _mm256_setzero_ps();
2823  if (dataInd + lengths[rangeIndex] > index_size) {
2824  return false;
2825  }
2826  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2827  ++dataInd) {
2828  const int64_t idx = indices[dataInd];
2829  if (idx < 0 || idx >= data_size) {
2830  return false;
2831  }
2832  float wgt = 1.f;
2833  float bio;
2834  if (weights) {
2835  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2836  }
2837  const float* scale_bias = reinterpret_cast<const float*>(
2838  &input[idx * fused_block_size + block_size]);
2839  bio = wgt * scale_bias[1];
2840  wgt = wgt * scale_bias[0];
2841  __m256 vbio = _mm256_set1_ps(bio);
2842  __m256 vwgt = _mm256_set1_ps(wgt);
2843  const uint8_t* ip = &input[idx * fused_block_size];
2844  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2845  ? (dataInd + prefdist_T0)
2846  : dataInd;
2847  const int64_t idx_pref_T0 = indices[next_T0];
2848  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2849  return false;
2850  }
2851  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2852  vop0 = _mm256_fmadd_ps(
2853  vwgt,
2854  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2855  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2856  _mm256_add_ps(vop0, vbio));
2857  _mm_prefetch(
2858  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2859  vop8 = _mm256_fmadd_ps(
2860  vwgt,
2861  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2862  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2863  _mm256_add_ps(vop8, vbio));
2864  // skip unnecessary prefetch of (&ip_next_T0[8])
2865  vop16 = _mm256_fmadd_ps(
2866  vwgt,
2867  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2868  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2869  _mm256_add_ps(vop16, vbio));
2870  // skip unnecessary prefetch of (&ip_next_T0[16])
2871  vop24 = _mm256_fmadd_ps(
2872  vwgt,
2873  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2874  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2875  _mm256_add_ps(vop24, vbio));
2876  // skip unnecessary prefetch of (&ip_next_T0[24])
2877  }
2878  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2879  _mm256_storeu_ps(&op[0], vop0);
2880  _mm256_storeu_ps(&op[8], vop8);
2881  _mm256_storeu_ps(&op[16], vop16);
2882  _mm256_storeu_ps(&op[24], vop24);
2883  } else {
2884  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2885  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2886  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2887  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2888  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2889  }
2890  }
2891  } else if (block_size == 16) {
2892  // unrolling 2 times
2893  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2894  float* op = &out[rangeIndex * block_size];
2895  __m256 vop0 = _mm256_setzero_ps();
2896  __m256 vop8 = _mm256_setzero_ps();
2897  if (dataInd + lengths[rangeIndex] > index_size) {
2898  return false;
2899  }
2900  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2901  ++dataInd) {
2902  const int64_t idx = indices[dataInd];
2903  if (idx < 0 || idx >= data_size) {
2904  return false;
2905  }
2906  float wgt = 1.f;
2907  float bio;
2908  if (weights) {
2909  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2910  }
2911  const float* scale_bias = reinterpret_cast<const float*>(
2912  &input[idx * fused_block_size + block_size]);
2913  bio = wgt * scale_bias[1];
2914  wgt = wgt * scale_bias[0];
2915  __m256 vbio = _mm256_set1_ps(bio);
2916  __m256 vwgt = _mm256_set1_ps(wgt);
2917  const uint8_t* ip = &input[idx * fused_block_size];
2918  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2919  ? (dataInd + prefdist_T0)
2920  : dataInd;
2921  const int64_t idx_pref_T0 = indices[next_T0];
2922  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2923  return false;
2924  }
2925  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2926  vop0 = _mm256_fmadd_ps(
2927  vwgt,
2928  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2929  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2930  _mm256_add_ps(vop0, vbio));
2931  _mm_prefetch(
2932  reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2933  vop8 = _mm256_fmadd_ps(
2934  vwgt,
2935  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2936  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2937  _mm256_add_ps(vop8, vbio));
2938  // skip unnecessary prefetch of (&ip_next_T0[8])
2939  }
2940  if (!normalize_by_lengths || lengths[rangeIndex] == 0) {
2941  _mm256_storeu_ps(&op[0], vop0);
2942  _mm256_storeu_ps(&op[8], vop8);
2943  } else {
2944  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2945  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2946  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2947  }
2948  }
2949  } else {
2950  // generic code
2951  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2952  float* op = &out[rangeIndex * block_size];
2953  int64_t j = 0;
2954  for (; j + 8 <= block_size; j += 8) {
2955  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2956  }
2957  for (; j < block_size; j++) {
2958  op[j] = 0.0f;
2959  }
2960  if (dataInd + lengths[rangeIndex] > index_size) {
2961  return false;
2962  }
2963  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2964  ++dataInd) {
2965  const int64_t idx = indices[dataInd];
2966  if (idx < 0 || idx >= data_size) {
2967  return false;
2968  }
2969  float wgt = 1.f;
2970  float bio;
2971  if (weights) {
2972  wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2973  }
2974  const float* scale_bias = reinterpret_cast<const float*>(
2975  &input[idx * fused_block_size + block_size]);
2976  bio = wgt * scale_bias[1];
2977  wgt = wgt * scale_bias[0];
2978  __m256 vbio = _mm256_set1_ps(bio);
2979  __m256 vwgt = _mm256_set1_ps(wgt);
2980  const uint8_t* ip = &input[idx * fused_block_size];
2981  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2982  ? (dataInd + prefdist_T0)
2983  : dataInd;
2984  const int64_t idx_pref_T0 = indices[next_T0];
2985  if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2986  return false;
2987  }
2988  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2989  j = 0;
2990  for (; j + 8 <= block_size; j += 8) {
2991  _mm256_storeu_ps(
2992  &op[j],
2993  _mm256_fmadd_ps(
2994  vwgt,
2995  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2996  reinterpret_cast<const __m128i*>(&ip[j])))),
2997  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2998  _mm_prefetch(
2999  reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
3000  }
3001  for (; j < block_size; j++) {
3002  op[j] += wgt * ((float)ip[j]) + bio;
3003  }
3004  }
3005  if (normalize_by_lengths && lengths[rangeIndex]) {
3006  float len_inv = 1.0f / lengths[rangeIndex];
3007  __m256 vlen_inv = _mm256_set1_ps(len_inv);
3008  j = 0;
3009  for (; j + 8 <= block_size; j += 8) {
3010  _mm256_storeu_ps(
3011  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
3012  }
3013  for (; j < block_size; j++) {
3014  op[j] = len_inv * op[j];
3015  }
3016  }
3017  }
3018  }
3019  return dataInd == index_size;
3020 }
3021 bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
3022  const int64_t block_size,
3023  const int64_t output_size,
3024  const int64_t index_size,
3025  const int64_t data_size,
3026  const uint8_t* input,
3027  const int64_t* indices,
3028  const int* lengths,
3029  const float* weights,
3030  bool normalize_by_lengths,
3031  float* out) {
3032  return Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<false>(
3033  block_size,
3034  output_size,
3035  index_size,
3036  data_size,
3037  input,
3038  indices,
3039  lengths,
3040  weights,
3041  normalize_by_lengths,
3042  out);
3043 }
3044 bool Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_true__avx2_fma(
3045  const int64_t block_size,
3046  const int64_t output_size,
3047  const int64_t index_size,
3048  const int64_t data_size,
3049  const uint8_t* input,
3050  const int64_t* indices,
3051  const int* lengths,
3052  const float* weights,
3053  bool normalize_by_lengths,
3054  float* out) {
3055  return Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma<true>(
3056  block_size,
3057  output_size,
3058  index_size,
3059  data_size,
3060  input,
3061  indices,
3062  lengths,
3063  weights,
3064  normalize_by_lengths,
3065  out);
3066 }
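// Editorial usage sketch (not in the original source; the helper and its
// values are hypothetical). It packs a tiny fused 8-bit rowwise table --
// block_size quantized bytes per row, followed by that row's float scale and
// float bias -- and pools it with the int64_t/uint8_t entry point defined
// above. Running it requires a CPU with AVX2 and FMA support.
inline bool Fused8BitRowwiseEmbeddingLookup_usage_sketch() {
  constexpr int64_t block_size = 16; // quantized values per row
  constexpr int64_t fused_block_size = block_size + 8; // + scale + bias (two floats)
  constexpr int64_t data_size = 2; // rows in the table
  uint8_t table[data_size * fused_block_size];
  auto put_float = [](uint8_t* dst, float v) {
    const unsigned char* p = reinterpret_cast<const unsigned char*>(&v);
    for (int k = 0; k < 4; ++k) {
      dst[k] = p[k];
    }
  };
  for (int64_t row = 0; row < data_size; ++row) {
    uint8_t* rp = table + row * fused_block_size;
    for (int64_t j = 0; j < block_size; ++j) {
      rp[j] = static_cast<uint8_t>(j); // arbitrary quantized values
    }
    put_float(rp + block_size, 0.5f); // per-row scale (read as scale_bias[0])
    put_float(rp + block_size + 4, 1.0f); // per-row bias (read as scale_bias[1])
  }
  const int64_t indices[] = {0, 1, 1}; // gather rows 0 and 1, then row 1 again
  const int lengths[] = {2, 1}; // output row 0 pools two rows, output row 1 pools one
  float out[2 * block_size]; // output_size * block_size floats
  // The `_false_` wrapper is the IS_WEIGHT_POSITIONAL == false instantiation;
  // with weights == nullptr the distinction does not matter.
  return Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float_false__avx2_fma(
      block_size,
      /*output_size=*/2,
      /*index_size=*/3,
      data_size,
      table,
      indices,
      lengths,
      /*weights=*/nullptr,
      /*normalize_by_lengths=*/false,
      out);
}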
3067 
3068 } // namespace caffe2