Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding_lookup_fused_8bit_rowwise_avx2.cc
24 #include <caffe2/core/common.h>
25 #include <caffe2/core/types.h>
26 #include <immintrin.h>
27 
28 namespace caffe2 {
29 
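// Pooled embedding lookup over rows stored with the fused 8-bit rowwise
// stride. For each output segment r the kernel computes
//   out[r][k] = sum_{i in segment r} weights[i] * input[indices[i]][k]
// and, when normalize_by_lengths is set and the segment is non-empty,
// divides the sum by lengths[r]. `indices` is a flat array of index_size
// row ids partitioned into output_size segments by `lengths`; `weights`
// may be null, in which case every row contributes with weight 1.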
30 void Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
31  const TIndex block_size,
32  const TIndex output_size,
33  const TIndex index_size,
34  const TIndex data_size,
35  const float* input,
36  const int32_t* indices,
37  const int* lengths,
38  const float* weights,
39  bool normalize_by_lengths,
40  float* out) {
41  const int32_t prefdist_T0 = 16;
42  const int32_t fused_block_size = block_size + 2;
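 // prefdist_T0 is the software-prefetch distance in indices: rows that will
 // be needed 16 iterations from now are pulled into cache ahead of time.
 // fused_block_size is the stride between consecutive rows; the two extra
 // floats presumably hold the 8 bytes of per-row scale and bias that the
 // fused 8-bit rowwise format appends to each row.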
43  if (block_size == 128) {
44  // unrolling 16 times
45  int32_t dataInd = 0;
46  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
47  float* op = &out[rangeIndex * block_size];
48  __m256 vop0 = _mm256_setzero_ps();
49  __m256 vop8 = _mm256_setzero_ps();
50  __m256 vop16 = _mm256_setzero_ps();
51  __m256 vop24 = _mm256_setzero_ps();
52  __m256 vop32 = _mm256_setzero_ps();
53  __m256 vop40 = _mm256_setzero_ps();
54  __m256 vop48 = _mm256_setzero_ps();
55  __m256 vop56 = _mm256_setzero_ps();
56  __m256 vop64 = _mm256_setzero_ps();
57  __m256 vop72 = _mm256_setzero_ps();
58  __m256 vop80 = _mm256_setzero_ps();
59  __m256 vop88 = _mm256_setzero_ps();
60  __m256 vop96 = _mm256_setzero_ps();
61  __m256 vop104 = _mm256_setzero_ps();
62  __m256 vop112 = _mm256_setzero_ps();
63  __m256 vop120 = _mm256_setzero_ps();
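 // Accumulate this segment's rows; dataInd indexes the flat `indices`
 // array and advances lengths[rangeIndex] times per output row.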
64  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
65  ++dataInd) {
66  const int32_t idx = indices[dataInd];
67  CAFFE_ENFORCE(
68  idx >= 0 && idx < data_size,
69  "Index ",
70  dataInd,
71  " is out of bounds: ",
72  idx,
73  ", range 0 to ",
74  data_size);
75  float wgt = 1.f;
76  if (weights) {
77  wgt = weights[dataInd];
78  }
79  __m256 vwgt = _mm256_set1_ps(wgt);
80  const float* ip = &input[idx * fused_block_size];
81  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
82  ? (dataInd + prefdist_T0)
83  : dataInd;
84  const int32_t idx_pref_T0 = indices[next_T0];
85  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
86  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
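 // Fused multiply-add of wgt * row into the vector accumulators, prefetching
 // the row needed prefdist_T0 iterations ahead one 64-byte cache line
 // (16 floats) at a time.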
87  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
88  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
89  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
90  // skip unnecessary prefetch of (&ip_next_T0[8])
91  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
92  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
93  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
94  // skip unnecessary prefetch of (&ip_next_T0[24])
95  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
96  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
97  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
98  // skip unnecessary prefetch of (&ip_next_T0[40])
99  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
100  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
101  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
102  // skip unnecessary prefetch of (&ip_next_T0[56])
103  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
104  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
105  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
106  // skip unnecessary prefetch of (&ip_next_T0[72])
107  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
108  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
109  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
110  // skip unnecessary prefetch of (&ip_next_T0[88])
111  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
112  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
113  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
114  // skip unnecessary prefetch of (&ip_next_T0[104])
115  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
116  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
117  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
118  // skip unnecessary prefetch of (&ip_next_T0[120])
119  }
120  if (normalize_by_lengths == false) {
121  _mm256_storeu_ps(&op[0], vop0);
122  _mm256_storeu_ps(&op[8], vop8);
123  _mm256_storeu_ps(&op[16], vop16);
124  _mm256_storeu_ps(&op[24], vop24);
125  _mm256_storeu_ps(&op[32], vop32);
126  _mm256_storeu_ps(&op[40], vop40);
127  _mm256_storeu_ps(&op[48], vop48);
128  _mm256_storeu_ps(&op[56], vop56);
129  _mm256_storeu_ps(&op[64], vop64);
130  _mm256_storeu_ps(&op[72], vop72);
131  _mm256_storeu_ps(&op[80], vop80);
132  _mm256_storeu_ps(&op[88], vop88);
133  _mm256_storeu_ps(&op[96], vop96);
134  _mm256_storeu_ps(&op[104], vop104);
135  _mm256_storeu_ps(&op[112], vop112);
136  _mm256_storeu_ps(&op[120], vop120);
137  } else if (lengths[rangeIndex]) {
138  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
139  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
140  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
141  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
142  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
143  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
144  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
145  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
146  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
147  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
148  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
149  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
150  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
151  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
152  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
153  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
154  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
155  }
156  }
157  } else if (block_size == 64) {
158  // unrolling 8 times
159  int32_t dataInd = 0;
160  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
161  float* op = &out[rangeIndex * block_size];
162  __m256 vop0 = _mm256_setzero_ps();
163  __m256 vop8 = _mm256_setzero_ps();
164  __m256 vop16 = _mm256_setzero_ps();
165  __m256 vop24 = _mm256_setzero_ps();
166  __m256 vop32 = _mm256_setzero_ps();
167  __m256 vop40 = _mm256_setzero_ps();
168  __m256 vop48 = _mm256_setzero_ps();
169  __m256 vop56 = _mm256_setzero_ps();
170  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
171  ++dataInd) {
172  const int32_t idx = indices[dataInd];
173  CAFFE_ENFORCE(
174  idx >= 0 && idx < data_size,
175  "Index ",
176  dataInd,
177  " is out of bounds: ",
178  idx,
179  ", range 0 to ",
180  data_size);
181  float wgt = 1.f;
182  if (weights) {
183  wgt = weights[dataInd];
184  }
185  __m256 vwgt = _mm256_set1_ps(wgt);
186  const float* ip = &input[idx * fused_block_size];
187  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
188  ? (dataInd + prefdist_T0)
189  : dataInd;
190  const int32_t idx_pref_T0 = indices[next_T0];
191  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
192  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
193  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
194  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
195  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
196  // skip unnecessary prefetch of (&ip_next_T0[8])
197  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
198  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
199  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
200  // skip unnecessary prefetch of (&ip_next_T0[24])
201  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
202  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
203  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
204  // skip unnecessary prefetch of (&ip_next_T0[40])
205  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
206  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
207  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
208  // skip unnecessary prefetch of (&ip_next_T0[56])
209  }
210  if (normalize_by_lengths == false) {
211  _mm256_storeu_ps(&op[0], vop0);
212  _mm256_storeu_ps(&op[8], vop8);
213  _mm256_storeu_ps(&op[16], vop16);
214  _mm256_storeu_ps(&op[24], vop24);
215  _mm256_storeu_ps(&op[32], vop32);
216  _mm256_storeu_ps(&op[40], vop40);
217  _mm256_storeu_ps(&op[48], vop48);
218  _mm256_storeu_ps(&op[56], vop56);
219  } else if (lengths[rangeIndex]) {
220  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
221  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
222  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
223  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
224  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
225  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
226  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
227  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
228  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
229  }
230  }
231  } else if (block_size == 32) {
232  // unrolling 4 times
233  int32_t dataInd = 0;
234  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
235  float* op = &out[rangeIndex * block_size];
236  __m256 vop0 = _mm256_setzero_ps();
237  __m256 vop8 = _mm256_setzero_ps();
238  __m256 vop16 = _mm256_setzero_ps();
239  __m256 vop24 = _mm256_setzero_ps();
240  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
241  ++dataInd) {
242  const int32_t idx = indices[dataInd];
243  CAFFE_ENFORCE(
244  idx >= 0 && idx < data_size,
245  "Index ",
246  dataInd,
247  " is out of bounds: ",
248  idx,
249  ", range 0 to ",
250  data_size);
251  float wgt = 1.f;
252  if (weights) {
253  wgt = weights[dataInd];
254  }
255  __m256 vwgt = _mm256_set1_ps(wgt);
256  const float* ip = &input[idx * fused_block_size];
257  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
258  ? (dataInd + prefdist_T0)
259  : dataInd;
260  const int32_t idx_pref_T0 = indices[next_T0];
261  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
262  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
263  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
264  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
265  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
266  // skip unnecessary prefetch of (&ip_next_T0[8])
267  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
268  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
269  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
270  // skip unnecessary prefetch of (&ip_next_T0[24])
271  }
272  if (normalize_by_lengths == false) {
273  _mm256_storeu_ps(&op[0], vop0);
274  _mm256_storeu_ps(&op[8], vop8);
275  _mm256_storeu_ps(&op[16], vop16);
276  _mm256_storeu_ps(&op[24], vop24);
277  } else if (lengths[rangeIndex]) {
278  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
279  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
280  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
281  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
282  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
283  }
284  }
285  } else if (block_size == 16) {
286  // unrolling 2 times
287  int32_t dataInd = 0;
288  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
289  float* op = &out[rangeIndex * block_size];
290  __m256 vop0 = _mm256_setzero_ps();
291  __m256 vop8 = _mm256_setzero_ps();
292  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
293  ++dataInd) {
294  const int32_t idx = indices[dataInd];
295  CAFFE_ENFORCE(
296  idx >= 0 && idx < data_size,
297  "Index ",
298  dataInd,
299  " is out of bounds: ",
300  idx,
301  ", range 0 to ",
302  data_size);
303  float wgt = 1.f;
304  if (weights) {
305  wgt = weights[dataInd];
306  }
307  __m256 vwgt = _mm256_set1_ps(wgt);
308  const float* ip = &input[idx * fused_block_size];
309  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
310  ? (dataInd + prefdist_T0)
311  : dataInd;
312  const int32_t idx_pref_T0 = indices[next_T0];
313  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
314  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
315  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
316  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
317  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
318  // skip unnecessary prefetch of (&ip_next_T0[8])
319  }
320  if (normalize_by_lengths == false) {
321  _mm256_storeu_ps(&op[0], vop0);
322  _mm256_storeu_ps(&op[8], vop8);
323  } else if (lengths[rangeIndex]) {
324  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
325  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
326  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
327  }
328  }
329  } else {
330  // generic code
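 // Fallback for arbitrary block_size: zero the output row, accumulate
 // eight floats at a time with FMA, and finish the remaining
 // block_size % 8 elements in a scalar tail loop.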
331  int32_t dataInd = 0;
332  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
333  float* op = &out[rangeIndex * block_size];
334  TIndex j = 0;
335  for (; j + 8 <= block_size; j += 8) {
336  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
337  }
338  for (; j < block_size; j++) {
339  op[j] = 0.0f;
340  }
341  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
342  ++dataInd) {
343  const int32_t idx = indices[dataInd];
344  CAFFE_ENFORCE(
345  idx >= 0 && idx < data_size,
346  "Index ",
347  dataInd,
348  " is out of bounds: ",
349  idx,
350  ", range 0 to ",
351  data_size);
352  float wgt = 1.f;
353  if (weights) {
354  wgt = weights[dataInd];
355  }
356  __m256 vwgt = _mm256_set1_ps(wgt);
357  const float* ip = &input[idx * fused_block_size];
358  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
359  ? (dataInd + prefdist_T0)
360  : dataInd;
361  const int32_t idx_pref_T0 = indices[next_T0];
362  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
363  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
364  j = 0;
365  for (; j + 8 <= block_size; j += 8) {
366  _mm256_storeu_ps(
367  &op[j],
368  _mm256_fmadd_ps(
369  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
370  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
371  }
372  for (; j < block_size; j++) {
373  op[j] += wgt * ip[j];
374  }
375  }
376  if (normalize_by_lengths && lengths[rangeIndex]) {
377  float len_inv = 1.0f / lengths[rangeIndex];
378  __m256 vlen_inv = _mm256_set1_ps(len_inv);
379  j = 0;
380  for (; j + 8 <= block_size; j += 8) {
381  _mm256_storeu_ps(
382  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
383  }
384  for (; j < block_size; j++) {
385  op[j] = len_inv * op[j];
386  }
387  }
388  }
389  }
390 }
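// Minimal usage sketch for the kernel above (hypothetical sizes and names;
// it assumes the two trailing floats per row are the fused per-row scale and
// bias, which this float specialization simply skips):
//
//   constexpr TIndex kBlock = 16;                  // embedding width
//   std::vector<float> table(100 * (kBlock + 2));  // 100 fused rows
//   int32_t indices[] = {3, 7};                    // rows pooled into segment 0
//   int lengths[] = {2};                           // one segment of two rows
//   std::vector<float> pooled(kBlock);
//   Fused8BitRowwiseEmbeddingLookup_int32_t_float_float__avx2_fma(
//       kBlock, /*output_size=*/1, /*index_size=*/2, /*data_size=*/100,
//       table.data(), indices, lengths, /*weights=*/nullptr,
//       /*normalize_by_lengths=*/true, pooled.data());
//   // pooled now holds the element-wise mean of rows 3 and 7.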
391 
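// Same kernel as above, specialized for 64-bit indices.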
392 void Fused8BitRowwiseEmbeddingLookup_int64_t_float_float__avx2_fma(
393  const TIndex block_size,
394  const TIndex output_size,
395  const TIndex index_size,
396  const TIndex data_size,
397  const float* input,
398  const int64_t* indices,
399  const int* lengths,
400  const float* weights,
401  bool normalize_by_lengths,
402  float* out) {
403  const int64_t prefdist_T0 = 16;
404  const int64_t fused_block_size = block_size + 2;
405  if (block_size == 128) {
406  // unrolling 16 times
407  int64_t dataInd = 0;
408  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
409  float* op = &out[rangeIndex * block_size];
410  __m256 vop0 = _mm256_setzero_ps();
411  __m256 vop8 = _mm256_setzero_ps();
412  __m256 vop16 = _mm256_setzero_ps();
413  __m256 vop24 = _mm256_setzero_ps();
414  __m256 vop32 = _mm256_setzero_ps();
415  __m256 vop40 = _mm256_setzero_ps();
416  __m256 vop48 = _mm256_setzero_ps();
417  __m256 vop56 = _mm256_setzero_ps();
418  __m256 vop64 = _mm256_setzero_ps();
419  __m256 vop72 = _mm256_setzero_ps();
420  __m256 vop80 = _mm256_setzero_ps();
421  __m256 vop88 = _mm256_setzero_ps();
422  __m256 vop96 = _mm256_setzero_ps();
423  __m256 vop104 = _mm256_setzero_ps();
424  __m256 vop112 = _mm256_setzero_ps();
425  __m256 vop120 = _mm256_setzero_ps();
426  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
427  ++dataInd) {
428  const int64_t idx = indices[dataInd];
429  CAFFE_ENFORCE(
430  idx >= 0 && idx < data_size,
431  "Index ",
432  dataInd,
433  " is out of bounds: ",
434  idx,
435  ", range 0 to ",
436  data_size);
437  float wgt = 1.f;
438  if (weights) {
439  wgt = weights[dataInd];
440  }
441  __m256 vwgt = _mm256_set1_ps(wgt);
442  const float* ip = &input[idx * fused_block_size];
443  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
444  ? (dataInd + prefdist_T0)
445  : dataInd;
446  const int64_t idx_pref_T0 = indices[next_T0];
447  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
448  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
449  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
450  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
451  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
452  // skip unnecessary prefetch of (&ip_next_T0[8])
453  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
454  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
455  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
456  // skip unnecessary prefetch of (&ip_next_T0[24])
457  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
458  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
459  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
460  // skip unnecessary prefetch of (&ip_next_T0[40])
461  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
462  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
463  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
464  // skip unnecessary prefetch of (&ip_next_T0[56])
465  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
466  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
467  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
468  // skip unnecessary prefetch of (&ip_next_T0[72])
469  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
470  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
471  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
472  // skip unnecessary prefetch of (&ip_next_T0[88])
473  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
474  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
475  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
476  // skip unnecessary prefetch of (&ip_next_T0[104])
477  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
478  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
479  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
480  // skip unnecessary prefetch of (&ip_next_T0[120])
481  }
482  if (normalize_by_lengths == false) {
483  _mm256_storeu_ps(&op[0], vop0);
484  _mm256_storeu_ps(&op[8], vop8);
485  _mm256_storeu_ps(&op[16], vop16);
486  _mm256_storeu_ps(&op[24], vop24);
487  _mm256_storeu_ps(&op[32], vop32);
488  _mm256_storeu_ps(&op[40], vop40);
489  _mm256_storeu_ps(&op[48], vop48);
490  _mm256_storeu_ps(&op[56], vop56);
491  _mm256_storeu_ps(&op[64], vop64);
492  _mm256_storeu_ps(&op[72], vop72);
493  _mm256_storeu_ps(&op[80], vop80);
494  _mm256_storeu_ps(&op[88], vop88);
495  _mm256_storeu_ps(&op[96], vop96);
496  _mm256_storeu_ps(&op[104], vop104);
497  _mm256_storeu_ps(&op[112], vop112);
498  _mm256_storeu_ps(&op[120], vop120);
499  } else if (lengths[rangeIndex]) {
500  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
501  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
502  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
503  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
504  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
505  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
506  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
507  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
508  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
509  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
510  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
511  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
512  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
513  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
514  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
515  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
516  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
517  }
518  }
519  } else if (block_size == 64) {
520  // unrolling 8 times
521  int64_t dataInd = 0;
522  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
523  float* op = &out[rangeIndex * block_size];
524  __m256 vop0 = _mm256_setzero_ps();
525  __m256 vop8 = _mm256_setzero_ps();
526  __m256 vop16 = _mm256_setzero_ps();
527  __m256 vop24 = _mm256_setzero_ps();
528  __m256 vop32 = _mm256_setzero_ps();
529  __m256 vop40 = _mm256_setzero_ps();
530  __m256 vop48 = _mm256_setzero_ps();
531  __m256 vop56 = _mm256_setzero_ps();
532  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
533  ++dataInd) {
534  const int64_t idx = indices[dataInd];
535  CAFFE_ENFORCE(
536  idx >= 0 && idx < data_size,
537  "Index ",
538  dataInd,
539  " is out of bounds: ",
540  idx,
541  ", range 0 to ",
542  data_size);
543  float wgt = 1.f;
544  if (weights) {
545  wgt = weights[dataInd];
546  }
547  __m256 vwgt = _mm256_set1_ps(wgt);
548  const float* ip = &input[idx * fused_block_size];
549  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
550  ? (dataInd + prefdist_T0)
551  : dataInd;
552  const int64_t idx_pref_T0 = indices[next_T0];
553  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
554  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
555  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
556  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
557  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
558  // skip unnecessary prefetch of (&ip_next_T0[8])
559  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
560  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
561  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
562  // skip unnecessary prefetch of (&ip_next_T0[24])
563  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
564  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
565  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
566  // skip unnecessary prefetch of (&ip_next_T0[40])
567  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
568  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
569  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
570  // skip unnecessary prefetch of (&ip_next_T0[56])
571  }
572  if (normalize_by_lengths == false) {
573  _mm256_storeu_ps(&op[0], vop0);
574  _mm256_storeu_ps(&op[8], vop8);
575  _mm256_storeu_ps(&op[16], vop16);
576  _mm256_storeu_ps(&op[24], vop24);
577  _mm256_storeu_ps(&op[32], vop32);
578  _mm256_storeu_ps(&op[40], vop40);
579  _mm256_storeu_ps(&op[48], vop48);
580  _mm256_storeu_ps(&op[56], vop56);
581  } else if (lengths[rangeIndex]) {
582  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
583  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
584  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
585  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
586  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
587  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
588  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
589  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
590  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
591  }
592  }
593  } else if (block_size == 32) {
594  // unrolling 4 times
595  int64_t dataInd = 0;
596  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
597  float* op = &out[rangeIndex * block_size];
598  __m256 vop0 = _mm256_setzero_ps();
599  __m256 vop8 = _mm256_setzero_ps();
600  __m256 vop16 = _mm256_setzero_ps();
601  __m256 vop24 = _mm256_setzero_ps();
602  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
603  ++dataInd) {
604  const int64_t idx = indices[dataInd];
605  CAFFE_ENFORCE(
606  idx >= 0 && idx < data_size,
607  "Index ",
608  dataInd,
609  " is out of bounds: ",
610  idx,
611  ", range 0 to ",
612  data_size);
613  float wgt = 1.f;
614  if (weights) {
615  wgt = weights[dataInd];
616  }
617  __m256 vwgt = _mm256_set1_ps(wgt);
618  const float* ip = &input[idx * fused_block_size];
619  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
620  ? (dataInd + prefdist_T0)
621  : dataInd;
622  const int64_t idx_pref_T0 = indices[next_T0];
623  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
624  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
625  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
626  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
627  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
628  // skip unnecessary prefetch of (&ip_next_T0[8])
629  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
630  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
631  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
632  // skip unnecessary prefetch of (&ip_next_T0[24])
633  }
634  if (normalize_by_lengths == false) {
635  _mm256_storeu_ps(&op[0], vop0);
636  _mm256_storeu_ps(&op[8], vop8);
637  _mm256_storeu_ps(&op[16], vop16);
638  _mm256_storeu_ps(&op[24], vop24);
639  } else if (lengths[rangeIndex]) {
640  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
641  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
642  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
643  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
644  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
645  }
646  }
647  } else if (block_size == 16) {
648  // unrolling 2 times
649  int64_t dataInd = 0;
650  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
651  float* op = &out[rangeIndex * block_size];
652  __m256 vop0 = _mm256_setzero_ps();
653  __m256 vop8 = _mm256_setzero_ps();
654  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
655  ++dataInd) {
656  const int64_t idx = indices[dataInd];
657  CAFFE_ENFORCE(
658  idx >= 0 && idx < data_size,
659  "Index ",
660  dataInd,
661  " is out of bounds: ",
662  idx,
663  ", range 0 to ",
664  data_size);
665  float wgt = 1.f;
666  if (weights) {
667  wgt = weights[dataInd];
668  }
669  __m256 vwgt = _mm256_set1_ps(wgt);
670  const float* ip = &input[idx * fused_block_size];
671  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
672  ? (dataInd + prefdist_T0)
673  : dataInd;
674  const int64_t idx_pref_T0 = indices[next_T0];
675  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
676  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
677  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
678  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
679  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
680  // skip unnecessary prefetch of (&ip_next_T0[8])
681  }
682  if (normalize_by_lengths == false) {
683  _mm256_storeu_ps(&op[0], vop0);
684  _mm256_storeu_ps(&op[8], vop8);
685  } else if (lengths[rangeIndex]) {
686  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
687  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
688  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
689  }
690  }
691  } else {
692  // generic code
693  int64_t dataInd = 0;
694  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
695  float* op = &out[rangeIndex * block_size];
696  TIndex j = 0;
697  for (; j + 8 <= block_size; j += 8) {
698  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
699  }
700  for (; j < block_size; j++) {
701  op[j] = 0.0f;
702  }
703  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
704  ++dataInd) {
705  const int64_t idx = indices[dataInd];
706  CAFFE_ENFORCE(
707  idx >= 0 && idx < data_size,
708  "Index ",
709  dataInd,
710  " is out of bounds: ",
711  idx,
712  ", range 0 to ",
713  data_size);
714  float wgt = 1.f;
715  if (weights) {
716  wgt = weights[dataInd];
717  }
718  __m256 vwgt = _mm256_set1_ps(wgt);
719  const float* ip = &input[idx * fused_block_size];
720  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
721  ? (dataInd + prefdist_T0)
722  : dataInd;
723  const int64_t idx_pref_T0 = indices[next_T0];
724  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
725  const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
726  j = 0;
727  for (; j + 8 <= block_size; j += 8) {
728  _mm256_storeu_ps(
729  &op[j],
730  _mm256_fmadd_ps(
731  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
732  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
733  }
734  for (; j < block_size; j++) {
735  op[j] += wgt * ip[j];
736  }
737  }
738  if (normalize_by_lengths && lengths[rangeIndex]) {
739  float len_inv = 1.0f / lengths[rangeIndex];
740  __m256 vlen_inv = _mm256_set1_ps(len_inv);
741  j = 0;
742  for (; j + 8 <= block_size; j += 8) {
743  _mm256_storeu_ps(
744  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
745  }
746  for (; j < block_size; j++) {
747  op[j] = len_inv * op[j];
748  }
749  }
750  }
751  }
752 }
753 
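// float16 variant: rows are stored in half precision and widened to float
// with _mm256_cvtph_ps before the weighted accumulation. The row stride is
// block_size + 4 halves, i.e. the same presumed 8 fused bytes of per-row
// scale and bias expressed in float16 elements.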
754 void Fused8BitRowwiseEmbeddingLookup_int32_t_float16_float__avx2_fma(
755  const TIndex block_size,
756  const TIndex output_size,
757  const TIndex index_size,
758  const TIndex data_size,
759  const float16* input,
760  const int32_t* indices,
761  const int* lengths,
762  const float* weights,
763  bool normalize_by_lengths,
764  float* out) {
765  const int32_t prefdist_T0 = 16;
766  const int32_t fused_block_size = block_size + 4;
767  if (block_size == 128) {
768  // unrolling 16 times
769  int32_t dataInd = 0;
770  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
771  float* op = &out[rangeIndex * block_size];
772  __m256 vop0 = _mm256_setzero_ps();
773  __m256 vop8 = _mm256_setzero_ps();
774  __m256 vop16 = _mm256_setzero_ps();
775  __m256 vop24 = _mm256_setzero_ps();
776  __m256 vop32 = _mm256_setzero_ps();
777  __m256 vop40 = _mm256_setzero_ps();
778  __m256 vop48 = _mm256_setzero_ps();
779  __m256 vop56 = _mm256_setzero_ps();
780  __m256 vop64 = _mm256_setzero_ps();
781  __m256 vop72 = _mm256_setzero_ps();
782  __m256 vop80 = _mm256_setzero_ps();
783  __m256 vop88 = _mm256_setzero_ps();
784  __m256 vop96 = _mm256_setzero_ps();
785  __m256 vop104 = _mm256_setzero_ps();
786  __m256 vop112 = _mm256_setzero_ps();
787  __m256 vop120 = _mm256_setzero_ps();
788  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
789  ++dataInd) {
790  const int32_t idx = indices[dataInd];
791  CAFFE_ENFORCE(
792  idx >= 0 && idx < data_size,
793  "Index ",
794  dataInd,
795  " is out of bounds: ",
796  idx,
797  ", range 0 to ",
798  data_size);
799  float wgt = 1.f;
800  if (weights) {
801  wgt = weights[dataInd];
802  }
803  __m256 vwgt = _mm256_set1_ps(wgt);
804  const float16* ip = &input[idx * fused_block_size];
805  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
806  ? (dataInd + prefdist_T0)
807  : dataInd;
808  const int32_t idx_pref_T0 = indices[next_T0];
809  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
810  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
811  vop0 = _mm256_fmadd_ps(
812  vwgt,
813  _mm256_cvtph_ps(
814  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
815  vop0);
816  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
817  vop8 = _mm256_fmadd_ps(
818  vwgt,
819  _mm256_cvtph_ps(
820  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
821  vop8);
822  // skip unnecessary prefetch of (&ip_next_T0[8])
823  vop16 = _mm256_fmadd_ps(
824  vwgt,
825  _mm256_cvtph_ps(
826  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
827  vop16);
828  // skip unnecessary prefetch of (&ip_next_T0[16])
829  vop24 = _mm256_fmadd_ps(
830  vwgt,
831  _mm256_cvtph_ps(
832  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
833  vop24);
834  // skip unnecessary prefetch of (&ip_next_T0[24])
835  vop32 = _mm256_fmadd_ps(
836  vwgt,
837  _mm256_cvtph_ps(
838  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
839  vop32);
840  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
841  vop40 = _mm256_fmadd_ps(
842  vwgt,
843  _mm256_cvtph_ps(
844  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
845  vop40);
846  // skip unnecessary prefetch of (&ip_next_T0[40])
847  vop48 = _mm256_fmadd_ps(
848  vwgt,
849  _mm256_cvtph_ps(
850  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
851  vop48);
852  // skip unnecessary prefetch of (&ip_next_T0[48])
853  vop56 = _mm256_fmadd_ps(
854  vwgt,
855  _mm256_cvtph_ps(
856  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
857  vop56);
858  // skip unnecessary prefetch of (&ip_next_T0[56])
859  vop64 = _mm256_fmadd_ps(
860  vwgt,
861  _mm256_cvtph_ps(
862  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
863  vop64);
864  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
865  vop72 = _mm256_fmadd_ps(
866  vwgt,
867  _mm256_cvtph_ps(
868  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
869  vop72);
870  // skip unnecessary prefetch of (&ip_next_T0[72])
871  vop80 = _mm256_fmadd_ps(
872  vwgt,
873  _mm256_cvtph_ps(
874  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
875  vop80);
876  // skip unnecessary prefetch of (&ip_next_T0[80])
877  vop88 = _mm256_fmadd_ps(
878  vwgt,
879  _mm256_cvtph_ps(
880  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
881  vop88);
882  // skip unnecessary prefetch of (&ip_next_T0[88])
883  vop96 = _mm256_fmadd_ps(
884  vwgt,
885  _mm256_cvtph_ps(
886  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
887  vop96);
888  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
889  vop104 = _mm256_fmadd_ps(
890  vwgt,
891  _mm256_cvtph_ps(
892  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
893  vop104);
894  // skip unnecessary prefetch of (&ip_next_T0[104])
895  vop112 = _mm256_fmadd_ps(
896  vwgt,
897  _mm256_cvtph_ps(
898  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
899  vop112);
900  // skip unnecessary prefetch of (&ip_next_T0[112])
901  vop120 = _mm256_fmadd_ps(
902  vwgt,
903  _mm256_cvtph_ps(
904  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
905  vop120);
906  // skip unnecessary prefetch of (&ip_next_T0[120])
907  }
908  if (normalize_by_lengths == false) {
909  _mm256_storeu_ps(&op[0], vop0);
910  _mm256_storeu_ps(&op[8], vop8);
911  _mm256_storeu_ps(&op[16], vop16);
912  _mm256_storeu_ps(&op[24], vop24);
913  _mm256_storeu_ps(&op[32], vop32);
914  _mm256_storeu_ps(&op[40], vop40);
915  _mm256_storeu_ps(&op[48], vop48);
916  _mm256_storeu_ps(&op[56], vop56);
917  _mm256_storeu_ps(&op[64], vop64);
918  _mm256_storeu_ps(&op[72], vop72);
919  _mm256_storeu_ps(&op[80], vop80);
920  _mm256_storeu_ps(&op[88], vop88);
921  _mm256_storeu_ps(&op[96], vop96);
922  _mm256_storeu_ps(&op[104], vop104);
923  _mm256_storeu_ps(&op[112], vop112);
924  _mm256_storeu_ps(&op[120], vop120);
925  } else if (lengths[rangeIndex]) {
926  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
927  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
928  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
929  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
930  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
931  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
932  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
933  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
934  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
935  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
936  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
937  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
938  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
939  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
940  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
941  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
942  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
943  }
944  }
945  } else if (block_size == 64) {
946  // unrolling 8 times
947  int32_t dataInd = 0;
948  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
949  float* op = &out[rangeIndex * block_size];
950  __m256 vop0 = _mm256_setzero_ps();
951  __m256 vop8 = _mm256_setzero_ps();
952  __m256 vop16 = _mm256_setzero_ps();
953  __m256 vop24 = _mm256_setzero_ps();
954  __m256 vop32 = _mm256_setzero_ps();
955  __m256 vop40 = _mm256_setzero_ps();
956  __m256 vop48 = _mm256_setzero_ps();
957  __m256 vop56 = _mm256_setzero_ps();
958  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
959  ++dataInd) {
960  const int32_t idx = indices[dataInd];
961  CAFFE_ENFORCE(
962  idx >= 0 && idx < data_size,
963  "Index ",
964  dataInd,
965  " is out of bounds: ",
966  idx,
967  ", range 0 to ",
968  data_size);
969  float wgt = 1.f;
970  if (weights) {
971  wgt = weights[dataInd];
972  }
973  __m256 vwgt = _mm256_set1_ps(wgt);
974  const float16* ip = &input[idx * fused_block_size];
975  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
976  ? (dataInd + prefdist_T0)
977  : dataInd;
978  const int32_t idx_pref_T0 = indices[next_T0];
979  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
980  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
981  vop0 = _mm256_fmadd_ps(
982  vwgt,
983  _mm256_cvtph_ps(
984  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
985  vop0);
986  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
987  vop8 = _mm256_fmadd_ps(
988  vwgt,
989  _mm256_cvtph_ps(
990  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
991  vop8);
992  // skip unnecessary prefetch of (&ip_next_T0[8])
993  vop16 = _mm256_fmadd_ps(
994  vwgt,
995  _mm256_cvtph_ps(
996  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
997  vop16);
998  // skip unnecessary prefetch of (&ip_next_T0[16])
999  vop24 = _mm256_fmadd_ps(
1000  vwgt,
1001  _mm256_cvtph_ps(
1002  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1003  vop24);
1004  // skip unnecessary prefetch of (&ip_next_T0[24])
1005  vop32 = _mm256_fmadd_ps(
1006  vwgt,
1007  _mm256_cvtph_ps(
1008  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1009  vop32);
1010  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1011  vop40 = _mm256_fmadd_ps(
1012  vwgt,
1013  _mm256_cvtph_ps(
1014  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1015  vop40);
1016  // skip unnecessary prefetch of (&ip_next_T0[40])
1017  vop48 = _mm256_fmadd_ps(
1018  vwgt,
1019  _mm256_cvtph_ps(
1020  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1021  vop48);
1022  // skip unnecessary prefetch of (&ip_next_T0[48])
1023  vop56 = _mm256_fmadd_ps(
1024  vwgt,
1025  _mm256_cvtph_ps(
1026  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1027  vop56);
1028  // skip unnecessary prefetch of (&ip_next_T0[56])
1029  }
1030  if (normalize_by_lengths == false) {
1031  _mm256_storeu_ps(&op[0], vop0);
1032  _mm256_storeu_ps(&op[8], vop8);
1033  _mm256_storeu_ps(&op[16], vop16);
1034  _mm256_storeu_ps(&op[24], vop24);
1035  _mm256_storeu_ps(&op[32], vop32);
1036  _mm256_storeu_ps(&op[40], vop40);
1037  _mm256_storeu_ps(&op[48], vop48);
1038  _mm256_storeu_ps(&op[56], vop56);
1039  } else if (lengths[rangeIndex]) {
1040  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1041  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1042  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1043  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1044  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1045  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1046  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1047  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1048  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1049  }
1050  }
1051  } else if (block_size == 32) {
1052  // unrolling 4 times
1053  int32_t dataInd = 0;
1054  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1055  float* op = &out[rangeIndex * block_size];
1056  __m256 vop0 = _mm256_setzero_ps();
1057  __m256 vop8 = _mm256_setzero_ps();
1058  __m256 vop16 = _mm256_setzero_ps();
1059  __m256 vop24 = _mm256_setzero_ps();
1060  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1061  ++dataInd) {
1062  const int32_t idx = indices[dataInd];
1063  CAFFE_ENFORCE(
1064  idx >= 0 && idx < data_size,
1065  "Index ",
1066  dataInd,
1067  " is out of bounds: ",
1068  idx,
1069  ", range 0 to ",
1070  data_size);
1071  float wgt = 1.f;
1072  if (weights) {
1073  wgt = weights[dataInd];
1074  }
1075  __m256 vwgt = _mm256_set1_ps(wgt);
1076  const float16* ip = &input[idx * fused_block_size];
1077  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1078  ? (dataInd + prefdist_T0)
1079  : dataInd;
1080  const int32_t idx_pref_T0 = indices[next_T0];
1081  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1082  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1083  vop0 = _mm256_fmadd_ps(
1084  vwgt,
1085  _mm256_cvtph_ps(
1086  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1087  vop0);
1088  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1089  vop8 = _mm256_fmadd_ps(
1090  vwgt,
1091  _mm256_cvtph_ps(
1092  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1093  vop8);
1094  // skip unnecessary prefetch of (&ip_next_T0[8])
1095  vop16 = _mm256_fmadd_ps(
1096  vwgt,
1097  _mm256_cvtph_ps(
1098  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1099  vop16);
1100  // skip unnecessary prefetch of (&ip_next_T0[16])
1101  vop24 = _mm256_fmadd_ps(
1102  vwgt,
1103  _mm256_cvtph_ps(
1104  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1105  vop24);
1106  // skip unnecessary prefetch of (&ip_next_T0[24])
1107  }
1108  if (normalize_by_lengths == false) {
1109  _mm256_storeu_ps(&op[0], vop0);
1110  _mm256_storeu_ps(&op[8], vop8);
1111  _mm256_storeu_ps(&op[16], vop16);
1112  _mm256_storeu_ps(&op[24], vop24);
1113  } else if (lengths[rangeIndex]) {
1114  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1115  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1116  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1117  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1118  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1119  }
1120  }
1121  } else if (block_size == 16) {
1122  // unrolling 2 times
1123  int32_t dataInd = 0;
1124  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1125  float* op = &out[rangeIndex * block_size];
1126  __m256 vop0 = _mm256_setzero_ps();
1127  __m256 vop8 = _mm256_setzero_ps();
1128  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1129  ++dataInd) {
1130  const int32_t idx = indices[dataInd];
1131  CAFFE_ENFORCE(
1132  idx >= 0 && idx < data_size,
1133  "Index ",
1134  dataInd,
1135  " is out of bounds: ",
1136  idx,
1137  ", range 0 to ",
1138  data_size);
1139  float wgt = 1.f;
1140  if (weights) {
1141  wgt = weights[dataInd];
1142  }
1143  __m256 vwgt = _mm256_set1_ps(wgt);
1144  const float16* ip = &input[idx * fused_block_size];
1145  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1146  ? (dataInd + prefdist_T0)
1147  : dataInd;
1148  const int32_t idx_pref_T0 = indices[next_T0];
1149  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1150  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1151  vop0 = _mm256_fmadd_ps(
1152  vwgt,
1153  _mm256_cvtph_ps(
1154  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1155  vop0);
1156  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1157  vop8 = _mm256_fmadd_ps(
1158  vwgt,
1159  _mm256_cvtph_ps(
1160  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1161  vop8);
1162  // skip unnecessary prefetch of (&ip_next_T0[8])
1163  }
1164  if (normalize_by_lengths == false) {
1165  _mm256_storeu_ps(&op[0], vop0);
1166  _mm256_storeu_ps(&op[8], vop8);
1167  } else if (lengths[rangeIndex]) {
1168  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1169  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1170  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1171  }
1172  }
1173  } else {
1174  // generic code
1175  int32_t dataInd = 0;
1176  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1177  float* op = &out[rangeIndex * block_size];
1178  TIndex j = 0;
1179  for (; j + 8 <= block_size; j += 8) {
1180  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1181  }
1182  for (; j < block_size; j++) {
1183  op[j] = 0.0f;
1184  }
1185  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1186  ++dataInd) {
1187  const int32_t idx = indices[dataInd];
1188  CAFFE_ENFORCE(
1189  idx >= 0 && idx < data_size,
1190  "Index ",
1191  dataInd,
1192  " is out of bounds: ",
1193  idx,
1194  ", range 0 to ",
1195  data_size);
1196  float wgt = 1.f;
1197  if (weights) {
1198  wgt = weights[dataInd];
1199  }
1200  __m256 vwgt = _mm256_set1_ps(wgt);
1201  const float16* ip = &input[idx * fused_block_size];
1202  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1203  ? (dataInd + prefdist_T0)
1204  : dataInd;
1205  const int32_t idx_pref_T0 = indices[next_T0];
1206  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1207  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1208  j = 0;
1209  for (; j + 8 <= block_size; j += 8) {
1210  _mm256_storeu_ps(
1211  &op[j],
1212  _mm256_fmadd_ps(
1213  vwgt,
1214  _mm256_cvtph_ps(_mm_loadu_si128(
1215  reinterpret_cast<const __m128i*>(&ip[j]))),
1216  _mm256_loadu_ps(&op[j])));
1217  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1218  }
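 // Scalar tail: _mm256_cvtph_ps has no single-element form, so each leftover
 // half is staged in an aligned 8-element buffer, converted as a vector, and
 // only lane 0 of the result is consumed.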
1219  float16 vtmp1[8] CAFFE2_ALIGNED(64);
1220  for (; j < block_size; j++) {
1221  vtmp1[0] = ip[j];
1222  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1223  op[j] += wgt * ((float*)(&vtmp2))[0];
1224  }
1225  }
1226  if (normalize_by_lengths && lengths[rangeIndex]) {
1227  float len_inv = 1.0f / lengths[rangeIndex];
1228  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1229  j = 0;
1230  for (; j + 8 <= block_size; j += 8) {
1231  _mm256_storeu_ps(
1232  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1233  }
1234  for (; j < block_size; j++) {
1235  op[j] = len_inv * op[j];
1236  }
1237  }
1238  }
1239  }
1240 }
1241 
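// float16 kernel specialized for 64-bit indices.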
1242 void Fused8BitRowwiseEmbeddingLookup_int64_t_float16_float__avx2_fma(
1243  const TIndex block_size,
1244  const TIndex output_size,
1245  const TIndex index_size,
1246  const TIndex data_size,
1247  const float16* input,
1248  const int64_t* indices,
1249  const int* lengths,
1250  const float* weights,
1251  bool normalize_by_lengths,
1252  float* out) {
1253  const int64_t prefdist_T0 = 16;
1254  const int64_t fused_block_size = block_size + 4;
1255  if (block_size == 128) {
1256  // unrolling 16 times
1257  int64_t dataInd = 0;
1258  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1259  float* op = &out[rangeIndex * block_size];
1260  __m256 vop0 = _mm256_setzero_ps();
1261  __m256 vop8 = _mm256_setzero_ps();
1262  __m256 vop16 = _mm256_setzero_ps();
1263  __m256 vop24 = _mm256_setzero_ps();
1264  __m256 vop32 = _mm256_setzero_ps();
1265  __m256 vop40 = _mm256_setzero_ps();
1266  __m256 vop48 = _mm256_setzero_ps();
1267  __m256 vop56 = _mm256_setzero_ps();
1268  __m256 vop64 = _mm256_setzero_ps();
1269  __m256 vop72 = _mm256_setzero_ps();
1270  __m256 vop80 = _mm256_setzero_ps();
1271  __m256 vop88 = _mm256_setzero_ps();
1272  __m256 vop96 = _mm256_setzero_ps();
1273  __m256 vop104 = _mm256_setzero_ps();
1274  __m256 vop112 = _mm256_setzero_ps();
1275  __m256 vop120 = _mm256_setzero_ps();
1276  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1277  ++dataInd) {
1278  const int64_t idx = indices[dataInd];
1279  CAFFE_ENFORCE(
1280  idx >= 0 && idx < data_size,
1281  "Index ",
1282  dataInd,
1283  " is out of bounds: ",
1284  idx,
1285  ", range 0 to ",
1286  data_size);
1287  float wgt = 1.f;
1288  if (weights) {
1289  wgt = weights[dataInd];
1290  }
1291  __m256 vwgt = _mm256_set1_ps(wgt);
1292  const float16* ip = &input[idx * fused_block_size];
1293  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1294  ? (dataInd + prefdist_T0)
1295  : dataInd;
1296  const int64_t idx_pref_T0 = indices[next_T0];
1297  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1298  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1299  vop0 = _mm256_fmadd_ps(
1300  vwgt,
1301  _mm256_cvtph_ps(
1302  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1303  vop0);
1304  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1305  vop8 = _mm256_fmadd_ps(
1306  vwgt,
1307  _mm256_cvtph_ps(
1308  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1309  vop8);
1310  // skip unnecessary prefetch of (&ip_next_T0[8])
1311  vop16 = _mm256_fmadd_ps(
1312  vwgt,
1313  _mm256_cvtph_ps(
1314  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1315  vop16);
1316  // skip unnecessary prefetch of (&ip_next_T0[16])
1317  vop24 = _mm256_fmadd_ps(
1318  vwgt,
1319  _mm256_cvtph_ps(
1320  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1321  vop24);
1322  // skip unnecessary prefetch of (&ip_next_T0[24])
1323  vop32 = _mm256_fmadd_ps(
1324  vwgt,
1325  _mm256_cvtph_ps(
1326  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1327  vop32);
1328  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1329  vop40 = _mm256_fmadd_ps(
1330  vwgt,
1331  _mm256_cvtph_ps(
1332  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1333  vop40);
1334  // skip unnecessary prefetch of (&ip_next_T0[40])
1335  vop48 = _mm256_fmadd_ps(
1336  vwgt,
1337  _mm256_cvtph_ps(
1338  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1339  vop48);
1340  // skip unnecessary prefetch of (&ip_next_T0[48])
1341  vop56 = _mm256_fmadd_ps(
1342  vwgt,
1343  _mm256_cvtph_ps(
1344  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1345  vop56);
1346  // skip unnecessary prefetch of (&ip_next_T0[56])
1347  vop64 = _mm256_fmadd_ps(
1348  vwgt,
1349  _mm256_cvtph_ps(
1350  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1351  vop64);
1352  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
1353  vop72 = _mm256_fmadd_ps(
1354  vwgt,
1355  _mm256_cvtph_ps(
1356  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1357  vop72);
1358  // skip unnecessary prefetch of (&ip_next_T0[72])
1359  vop80 = _mm256_fmadd_ps(
1360  vwgt,
1361  _mm256_cvtph_ps(
1362  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1363  vop80);
1364  // skip unnecessary prefetch of (&ip_next_T0[80])
1365  vop88 = _mm256_fmadd_ps(
1366  vwgt,
1367  _mm256_cvtph_ps(
1368  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1369  vop88);
1370  // skip unnecessary prefetch of (&ip_next_T0[88])
1371  vop96 = _mm256_fmadd_ps(
1372  vwgt,
1373  _mm256_cvtph_ps(
1374  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1375  vop96);
1376  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
1377  vop104 = _mm256_fmadd_ps(
1378  vwgt,
1379  _mm256_cvtph_ps(
1380  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1381  vop104);
1382  // skip unnecessary prefetch of (&ip_next_T0[104])
1383  vop112 = _mm256_fmadd_ps(
1384  vwgt,
1385  _mm256_cvtph_ps(
1386  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1387  vop112);
1388  // skip unnecessary prefetch of (&ip_next_T0[112])
1389  vop120 = _mm256_fmadd_ps(
1390  vwgt,
1391  _mm256_cvtph_ps(
1392  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1393  vop120);
1394  // skip unnecessary prefetch of (&ip_next_T0[120])
1395  }
1396  if (normalize_by_lengths == false) {
1397  _mm256_storeu_ps(&op[0], vop0);
1398  _mm256_storeu_ps(&op[8], vop8);
1399  _mm256_storeu_ps(&op[16], vop16);
1400  _mm256_storeu_ps(&op[24], vop24);
1401  _mm256_storeu_ps(&op[32], vop32);
1402  _mm256_storeu_ps(&op[40], vop40);
1403  _mm256_storeu_ps(&op[48], vop48);
1404  _mm256_storeu_ps(&op[56], vop56);
1405  _mm256_storeu_ps(&op[64], vop64);
1406  _mm256_storeu_ps(&op[72], vop72);
1407  _mm256_storeu_ps(&op[80], vop80);
1408  _mm256_storeu_ps(&op[88], vop88);
1409  _mm256_storeu_ps(&op[96], vop96);
1410  _mm256_storeu_ps(&op[104], vop104);
1411  _mm256_storeu_ps(&op[112], vop112);
1412  _mm256_storeu_ps(&op[120], vop120);
1413  } else if (lengths[rangeIndex]) {
1414  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1415  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1416  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1417  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1418  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1419  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1420  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1421  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1422  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1423  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1424  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1425  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1426  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1427  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1428  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1429  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1430  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1431  }
1432  }
1433  } else if (block_size == 64) {
1434  // unrolling 8 times
1435  int64_t dataInd = 0;
1436  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1437  float* op = &out[rangeIndex * block_size];
1438  __m256 vop0 = _mm256_setzero_ps();
1439  __m256 vop8 = _mm256_setzero_ps();
1440  __m256 vop16 = _mm256_setzero_ps();
1441  __m256 vop24 = _mm256_setzero_ps();
1442  __m256 vop32 = _mm256_setzero_ps();
1443  __m256 vop40 = _mm256_setzero_ps();
1444  __m256 vop48 = _mm256_setzero_ps();
1445  __m256 vop56 = _mm256_setzero_ps();
1446  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1447  ++dataInd) {
1448  const int64_t idx = indices[dataInd];
1449  CAFFE_ENFORCE(
1450  idx >= 0 && idx < data_size,
1451  "Index ",
1452  dataInd,
1453  " is out of bounds: ",
1454  idx,
1455  ", range 0 to ",
1456  data_size);
1457  float wgt = 1.f;
1458  if (weights) {
1459  wgt = weights[dataInd];
1460  }
1461  __m256 vwgt = _mm256_set1_ps(wgt);
1462  const float16* ip = &input[idx * fused_block_size];
1463  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1464  ? (dataInd + prefdist_T0)
1465  : dataInd;
1466  const int64_t idx_pref_T0 = indices[next_T0];
1467  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1468  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1469  vop0 = _mm256_fmadd_ps(
1470  vwgt,
1471  _mm256_cvtph_ps(
1472  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1473  vop0);
1474  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1475  vop8 = _mm256_fmadd_ps(
1476  vwgt,
1477  _mm256_cvtph_ps(
1478  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1479  vop8);
1480  // skip unnecessary prefetch of (&ip_next_T0[8])
1481  vop16 = _mm256_fmadd_ps(
1482  vwgt,
1483  _mm256_cvtph_ps(
1484  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1485  vop16);
1486  // skip unnecessary prefetch of (&ip_next_T0[16])
1487  vop24 = _mm256_fmadd_ps(
1488  vwgt,
1489  _mm256_cvtph_ps(
1490  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1491  vop24);
1492  // skip unnecessary prefetch of (&ip_next_T0[24])
1493  vop32 = _mm256_fmadd_ps(
1494  vwgt,
1495  _mm256_cvtph_ps(
1496  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1497  vop32);
1498  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1499  vop40 = _mm256_fmadd_ps(
1500  vwgt,
1501  _mm256_cvtph_ps(
1502  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1503  vop40);
1504  // skip unnecessary prefetch of (&ip_next_T0[40])
1505  vop48 = _mm256_fmadd_ps(
1506  vwgt,
1507  _mm256_cvtph_ps(
1508  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1509  vop48);
1510  // skip unnecessary prefetch of (&ip_next_T0[48])
1511  vop56 = _mm256_fmadd_ps(
1512  vwgt,
1513  _mm256_cvtph_ps(
1514  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1515  vop56);
1516  // skip unnecessary prefetch of (&ip_next_T0[56])
1517  }
1518  if (normalize_by_lengths == false) {
1519  _mm256_storeu_ps(&op[0], vop0);
1520  _mm256_storeu_ps(&op[8], vop8);
1521  _mm256_storeu_ps(&op[16], vop16);
1522  _mm256_storeu_ps(&op[24], vop24);
1523  _mm256_storeu_ps(&op[32], vop32);
1524  _mm256_storeu_ps(&op[40], vop40);
1525  _mm256_storeu_ps(&op[48], vop48);
1526  _mm256_storeu_ps(&op[56], vop56);
1527  } else if (lengths[rangeIndex]) {
1528  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1529  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1530  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1531  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1532  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1533  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1534  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1535  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1536  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1537  }
1538  }
1539  } else if (block_size == 32) {
1540  // unrolling 4 times
1541  int64_t dataInd = 0;
1542  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1543  float* op = &out[rangeIndex * block_size];
1544  __m256 vop0 = _mm256_setzero_ps();
1545  __m256 vop8 = _mm256_setzero_ps();
1546  __m256 vop16 = _mm256_setzero_ps();
1547  __m256 vop24 = _mm256_setzero_ps();
1548  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1549  ++dataInd) {
1550  const int64_t idx = indices[dataInd];
1551  CAFFE_ENFORCE(
1552  idx >= 0 && idx < data_size,
1553  "Index ",
1554  dataInd,
1555  " is out of bounds: ",
1556  idx,
1557  ", range 0 to ",
1558  data_size);
1559  float wgt = 1.f;
1560  if (weights) {
1561  wgt = weights[dataInd];
1562  }
1563  __m256 vwgt = _mm256_set1_ps(wgt);
1564  const float16* ip = &input[idx * fused_block_size];
1565  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1566  ? (dataInd + prefdist_T0)
1567  : dataInd;
1568  const int64_t idx_pref_T0 = indices[next_T0];
1569  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1570  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1571  vop0 = _mm256_fmadd_ps(
1572  vwgt,
1573  _mm256_cvtph_ps(
1574  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1575  vop0);
1576  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1577  vop8 = _mm256_fmadd_ps(
1578  vwgt,
1579  _mm256_cvtph_ps(
1580  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1581  vop8);
1582  // skip unnecessary prefetch of (&ip_next_T0[8])
1583  vop16 = _mm256_fmadd_ps(
1584  vwgt,
1585  _mm256_cvtph_ps(
1586  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1587  vop16);
1588  // skip unnecessary prefetch of (&ip_next_T0[16])
1589  vop24 = _mm256_fmadd_ps(
1590  vwgt,
1591  _mm256_cvtph_ps(
1592  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1593  vop24);
1594  // skip unnecessary prefetch of (&ip_next_T0[24])
1595  }
1596  if (normalize_by_lengths == false) {
1597  _mm256_storeu_ps(&op[0], vop0);
1598  _mm256_storeu_ps(&op[8], vop8);
1599  _mm256_storeu_ps(&op[16], vop16);
1600  _mm256_storeu_ps(&op[24], vop24);
1601  } else if (lengths[rangeIndex]) {
1602  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1603  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1604  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1605  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1606  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1607  }
1608  }
1609  } else if (block_size == 16) {
1610  // unrolling 2 times
1611  int64_t dataInd = 0;
1612  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1613  float* op = &out[rangeIndex * block_size];
1614  __m256 vop0 = _mm256_setzero_ps();
1615  __m256 vop8 = _mm256_setzero_ps();
1616  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1617  ++dataInd) {
1618  const int64_t idx = indices[dataInd];
1619  CAFFE_ENFORCE(
1620  idx >= 0 && idx < data_size,
1621  "Index ",
1622  dataInd,
1623  " is out of bounds: ",
1624  idx,
1625  ", range 0 to ",
1626  data_size);
1627  float wgt = 1.f;
1628  if (weights) {
1629  wgt = weights[dataInd];
1630  }
1631  __m256 vwgt = _mm256_set1_ps(wgt);
1632  const float16* ip = &input[idx * fused_block_size];
1633  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1634  ? (dataInd + prefdist_T0)
1635  : dataInd;
1636  const int64_t idx_pref_T0 = indices[next_T0];
1637  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1638  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1639  vop0 = _mm256_fmadd_ps(
1640  vwgt,
1641  _mm256_cvtph_ps(
1642  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1643  vop0);
1644  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1645  vop8 = _mm256_fmadd_ps(
1646  vwgt,
1647  _mm256_cvtph_ps(
1648  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1649  vop8);
1650  // skip unnecessary prefetch of (&ip_next_T0[8])
1651  }
1652  if (normalize_by_lengths == false) {
1653  _mm256_storeu_ps(&op[0], vop0);
1654  _mm256_storeu_ps(&op[8], vop8);
1655  } else if (lengths[rangeIndex]) {
1656  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1657  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1658  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1659  }
1660  }
1661  } else {
1662  // generic code
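 // Arbitrary block_size: zero the output row, accumulate 8 floats per
 // iteration with AVX2, and fall back to scalar code for the tail elements.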
1663  int64_t dataInd = 0;
1664  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1665  float* op = &out[rangeIndex * block_size];
1666  TIndex j = 0;
1667  for (; j + 8 <= block_size; j += 8) {
1668  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1669  }
1670  for (; j < block_size; j++) {
1671  op[j] = 0.0f;
1672  }
1673  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1674  ++dataInd) {
1675  const int64_t idx = indices[dataInd];
1676  CAFFE_ENFORCE(
1677  idx >= 0 && idx < data_size,
1678  "Index ",
1679  dataInd,
1680  " is out of bounds: ",
1681  idx,
1682  ", range 0 to ",
1683  data_size);
1684  float wgt = 1.f;
1685  if (weights) {
1686  wgt = weights[dataInd];
1687  }
1688  __m256 vwgt = _mm256_set1_ps(wgt);
1689  const float16* ip = &input[idx * fused_block_size];
1690  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1691  ? (dataInd + prefdist_T0)
1692  : dataInd;
1693  const int64_t idx_pref_T0 = indices[next_T0];
1694  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1695  const float16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1696  j = 0;
1697  for (; j + 8 <= block_size; j += 8) {
1698  _mm256_storeu_ps(
1699  &op[j],
1700  _mm256_fmadd_ps(
1701  vwgt,
1702  _mm256_cvtph_ps(_mm_loadu_si128(
1703  reinterpret_cast<const __m128i*>(&ip[j]))),
1704  _mm256_loadu_ps(&op[j])));
1705  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1706  }
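 // Scalar tail for the remaining (block_size % 8) elements: stage each
 // half-precision value in an aligned scratch buffer, widen it to float via
 // _mm256_cvtph_ps, and accumulate lane 0.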
1707  float16 vtmp1[8] CAFFE2_ALIGNED(64);
1708  for (; j < block_size; j++) {
1709  vtmp1[0] = ip[j];
1710  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1711  op[j] += wgt * ((float*)(&vtmp2))[0];
1712  }
1713  }
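 // Optionally turn the sum into a mean by scaling with 1/length; empty
 // segments are skipped to avoid dividing by zero.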
1714  if (normalize_by_lengths && lengths[rangeIndex]) {
1715  float len_inv = 1.0f / lengths[rangeIndex];
1716  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1717  j = 0;
1718  for (; j + 8 <= block_size; j += 8) {
1719  _mm256_storeu_ps(
1720  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1721  }
1722  for (; j < block_size; j++) {
1723  op[j] = len_inv * op[j];
1724  }
1725  }
1726  }
1727  }
1728 }
1729 
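// Fused 8-bit rowwise lookup for uint8_t input with 32-bit indices: each row
// of `input` stores block_size quantized bytes followed by a float scale and
// a float bias (hence the fused row stride of block_size + 8 bytes). For each
// output segment the kernel accumulates wgt * (scale * q + bias) over the
// segment's indices and optionally normalizes by the segment length.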
1730 void Fused8BitRowwiseEmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
1731  const TIndex block_size,
1732  const TIndex output_size,
1733  const TIndex index_size,
1734  const TIndex data_size,
1735  const uint8_t* input,
1736  const int32_t* indices,
1737  const int* lengths,
1738  const float* weights,
1739  bool normalize_by_lengths,
1740  float* out) {
1741  const int32_t prefdist_T0 = 16;
1742  const int32_t fused_block_size = block_size + 8;
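 // Fully unrolled fast paths for the common embedding widths (128, 64, 32,
 // 16); any other block_size falls through to the generic loop at the end.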
1743  if (block_size == 128) {
1744  // unrolling 16 times
1745  int32_t dataInd = 0;
1746  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1747  float* op = &out[rangeIndex * block_size];
1748  __m256 vop0 = _mm256_setzero_ps();
1749  __m256 vop8 = _mm256_setzero_ps();
1750  __m256 vop16 = _mm256_setzero_ps();
1751  __m256 vop24 = _mm256_setzero_ps();
1752  __m256 vop32 = _mm256_setzero_ps();
1753  __m256 vop40 = _mm256_setzero_ps();
1754  __m256 vop48 = _mm256_setzero_ps();
1755  __m256 vop56 = _mm256_setzero_ps();
1756  __m256 vop64 = _mm256_setzero_ps();
1757  __m256 vop72 = _mm256_setzero_ps();
1758  __m256 vop80 = _mm256_setzero_ps();
1759  __m256 vop88 = _mm256_setzero_ps();
1760  __m256 vop96 = _mm256_setzero_ps();
1761  __m256 vop104 = _mm256_setzero_ps();
1762  __m256 vop112 = _mm256_setzero_ps();
1763  __m256 vop120 = _mm256_setzero_ps();
1764  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1765  ++dataInd) {
1766  const int32_t idx = indices[dataInd];
1767  CAFFE_ENFORCE(
1768  idx >= 0 && idx < data_size,
1769  "Index ",
1770  dataInd,
1771  " is out of bounds: ",
1772  idx,
1773  ", range 0 to ",
1774  data_size);
1775  float wgt = 1.f;
1776  float bio;
1777  if (weights) {
1778  wgt = weights[dataInd];
1779  }
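 // The trailing 8 bytes of the fused row hold the per-row quantization scale
 // and bias; folding them into wgt and bio lets each 8-element step be a
 // single FMA: vop = vwgt * q + (vop + vbio), i.e. out += wgt * (scale*q + bias).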
1780  const float* scale_bias = reinterpret_cast<const float*>(
1781  &input[idx * fused_block_size + block_size]);
1782  bio = wgt * scale_bias[1];
1783  wgt = wgt * scale_bias[0];
1784  __m256 vbio = _mm256_set1_ps(bio);
1785  __m256 vwgt = _mm256_set1_ps(wgt);
1786  const uint8_t* ip = &input[idx * fused_block_size];
1787  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1788  ? (dataInd + prefdist_T0)
1789  : dataInd;
1790  const int32_t idx_pref_T0 = indices[next_T0];
1791  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1792  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1793  vop0 = _mm256_fmadd_ps(
1794  vwgt,
1795  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1796  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
1797  _mm256_add_ps(vop0, vbio));
1798  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1799  vop8 = _mm256_fmadd_ps(
1800  vwgt,
1801  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1802  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
1803  _mm256_add_ps(vop8, vbio));
1804  // skip unnecessary prefetch of (&ip_next_T0[8])
1805  vop16 = _mm256_fmadd_ps(
1806  vwgt,
1807  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1808  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
1809  _mm256_add_ps(vop16, vbio));
1810  // skip unnecessary prefetch of (&ip_next_T0[16])
1811  vop24 = _mm256_fmadd_ps(
1812  vwgt,
1813  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1814  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
1815  _mm256_add_ps(vop24, vbio));
1816  // skip unnecessary prefetch of (&ip_next_T0[24])
1817  vop32 = _mm256_fmadd_ps(
1818  vwgt,
1819  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1820  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
1821  _mm256_add_ps(vop32, vbio));
1822  // skip unnecessary prefetch of (&ip_next_T0[32])
1823  vop40 = _mm256_fmadd_ps(
1824  vwgt,
1825  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1826  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
1827  _mm256_add_ps(vop40, vbio));
1828  // skip unnecessary prefetch of (&ip_next_T0[40])
1829  vop48 = _mm256_fmadd_ps(
1830  vwgt,
1831  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1832  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
1833  _mm256_add_ps(vop48, vbio));
1834  // skip unnecessary prefetch of (&ip_next_T0[48])
1835  vop56 = _mm256_fmadd_ps(
1836  vwgt,
1837  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1838  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
1839  _mm256_add_ps(vop56, vbio));
1840  // skip unnecessary prefetch of (&ip_next_T0[56])
1841  vop64 = _mm256_fmadd_ps(
1842  vwgt,
1843  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1844  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
1845  _mm256_add_ps(vop64, vbio));
1846  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
1847  vop72 = _mm256_fmadd_ps(
1848  vwgt,
1849  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1850  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
1851  _mm256_add_ps(vop72, vbio));
1852  // skip unnecessary prefetch of (&ip_next_T0[72])
1853  vop80 = _mm256_fmadd_ps(
1854  vwgt,
1855  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1856  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
1857  _mm256_add_ps(vop80, vbio));
1858  // skip unnecessary prefetch of (&ip_next_T0[80])
1859  vop88 = _mm256_fmadd_ps(
1860  vwgt,
1861  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1862  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
1863  _mm256_add_ps(vop88, vbio));
1864  // skip unnecessary prefetch of (&ip_next_T0[88])
1865  vop96 = _mm256_fmadd_ps(
1866  vwgt,
1867  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1868  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
1869  _mm256_add_ps(vop96, vbio));
1870  // skip unnecessary prefetch of (&ip_next_T0[96])
1871  vop104 = _mm256_fmadd_ps(
1872  vwgt,
1873  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1874  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
1875  _mm256_add_ps(vop104, vbio));
1876  // skip unnecessary prefetch of (&ip_next_T0[104])
1877  vop112 = _mm256_fmadd_ps(
1878  vwgt,
1879  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1880  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
1881  _mm256_add_ps(vop112, vbio));
1882  // skip unnecessary prefetch of (&ip_next_T0[112])
1883  vop120 = _mm256_fmadd_ps(
1884  vwgt,
1885  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1886  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
1887  _mm256_add_ps(vop120, vbio));
1888  // skip unnecessary prefetch of (&ip_next_T0[120])
1889  }
1890  if (normalize_by_lengths == false) {
1891  _mm256_storeu_ps(&op[0], vop0);
1892  _mm256_storeu_ps(&op[8], vop8);
1893  _mm256_storeu_ps(&op[16], vop16);
1894  _mm256_storeu_ps(&op[24], vop24);
1895  _mm256_storeu_ps(&op[32], vop32);
1896  _mm256_storeu_ps(&op[40], vop40);
1897  _mm256_storeu_ps(&op[48], vop48);
1898  _mm256_storeu_ps(&op[56], vop56);
1899  _mm256_storeu_ps(&op[64], vop64);
1900  _mm256_storeu_ps(&op[72], vop72);
1901  _mm256_storeu_ps(&op[80], vop80);
1902  _mm256_storeu_ps(&op[88], vop88);
1903  _mm256_storeu_ps(&op[96], vop96);
1904  _mm256_storeu_ps(&op[104], vop104);
1905  _mm256_storeu_ps(&op[112], vop112);
1906  _mm256_storeu_ps(&op[120], vop120);
1907  } else if (lengths[rangeIndex]) {
1908  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1909  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1910  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1911  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1912  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1913  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1914  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1915  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1916  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1917  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1918  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1919  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1920  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1921  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1922  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1923  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1924  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1925  }
1926  }
1927  } else if (block_size == 64) {
1928  // unrolling 8 times
1929  int32_t dataInd = 0;
1930  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1931  float* op = &out[rangeIndex * block_size];
1932  __m256 vop0 = _mm256_setzero_ps();
1933  __m256 vop8 = _mm256_setzero_ps();
1934  __m256 vop16 = _mm256_setzero_ps();
1935  __m256 vop24 = _mm256_setzero_ps();
1936  __m256 vop32 = _mm256_setzero_ps();
1937  __m256 vop40 = _mm256_setzero_ps();
1938  __m256 vop48 = _mm256_setzero_ps();
1939  __m256 vop56 = _mm256_setzero_ps();
1940  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1941  ++dataInd) {
1942  const int32_t idx = indices[dataInd];
1943  CAFFE_ENFORCE(
1944  idx >= 0 && idx < data_size,
1945  "Index ",
1946  dataInd,
1947  " is out of bounds: ",
1948  idx,
1949  ", range 0 to ",
1950  data_size);
1951  float wgt = 1.f;
1952  float bio;
1953  if (weights) {
1954  wgt = weights[dataInd];
1955  }
1956  const float* scale_bias = reinterpret_cast<const float*>(
1957  &input[idx * fused_block_size + block_size]);
1958  bio = wgt * scale_bias[1];
1959  wgt = wgt * scale_bias[0];
1960  __m256 vbio = _mm256_set1_ps(bio);
1961  __m256 vwgt = _mm256_set1_ps(wgt);
1962  const uint8_t* ip = &input[idx * fused_block_size];
1963  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1964  ? (dataInd + prefdist_T0)
1965  : dataInd;
1966  const int32_t idx_pref_T0 = indices[next_T0];
1967  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1968  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1969  vop0 = _mm256_fmadd_ps(
1970  vwgt,
1971  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1972  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
1973  _mm256_add_ps(vop0, vbio));
1974  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1975  vop8 = _mm256_fmadd_ps(
1976  vwgt,
1977  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1978  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
1979  _mm256_add_ps(vop8, vbio));
1980  // skip unnecessary prefetch of (&ip_next_T0[8])
1981  vop16 = _mm256_fmadd_ps(
1982  vwgt,
1983  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1984  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
1985  _mm256_add_ps(vop16, vbio));
1986  // skip unnecessary prefetch of (&ip_next_T0[16])
1987  vop24 = _mm256_fmadd_ps(
1988  vwgt,
1989  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1990  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
1991  _mm256_add_ps(vop24, vbio));
1992  // skip unnecessary prefetch of (&ip_next_T0[24])
1993  vop32 = _mm256_fmadd_ps(
1994  vwgt,
1995  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1996  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
1997  _mm256_add_ps(vop32, vbio));
1998  // skip unnecessary prefetch of (&ip_next_T0[32])
1999  vop40 = _mm256_fmadd_ps(
2000  vwgt,
2001  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2002  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2003  _mm256_add_ps(vop40, vbio));
2004  // skip unnecessary prefetch of (&ip_next_T0[40])
2005  vop48 = _mm256_fmadd_ps(
2006  vwgt,
2007  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2008  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2009  _mm256_add_ps(vop48, vbio));
2010  // skip unnecessary prefetch of (&ip_next_T0[48])
2011  vop56 = _mm256_fmadd_ps(
2012  vwgt,
2013  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2014  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2015  _mm256_add_ps(vop56, vbio));
2016  // skip unnecessary prefetch of (&ip_next_T0[56])
2017  }
2018  if (normalize_by_lengths == false) {
2019  _mm256_storeu_ps(&op[0], vop0);
2020  _mm256_storeu_ps(&op[8], vop8);
2021  _mm256_storeu_ps(&op[16], vop16);
2022  _mm256_storeu_ps(&op[24], vop24);
2023  _mm256_storeu_ps(&op[32], vop32);
2024  _mm256_storeu_ps(&op[40], vop40);
2025  _mm256_storeu_ps(&op[48], vop48);
2026  _mm256_storeu_ps(&op[56], vop56);
2027  } else if (lengths[rangeIndex]) {
2028  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2029  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2030  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2031  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2032  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2033  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2034  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2035  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2036  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2037  }
2038  }
2039  } else if (block_size == 32) {
2040  // unrolling 4 times
2041  int32_t dataInd = 0;
2042  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2043  float* op = &out[rangeIndex * block_size];
2044  __m256 vop0 = _mm256_setzero_ps();
2045  __m256 vop8 = _mm256_setzero_ps();
2046  __m256 vop16 = _mm256_setzero_ps();
2047  __m256 vop24 = _mm256_setzero_ps();
2048  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2049  ++dataInd) {
2050  const int32_t idx = indices[dataInd];
2051  CAFFE_ENFORCE(
2052  idx >= 0 && idx < data_size,
2053  "Index ",
2054  dataInd,
2055  " is out of bounds: ",
2056  idx,
2057  ", range 0 to ",
2058  data_size);
2059  float wgt = 1.f;
2060  float bio;
2061  if (weights) {
2062  wgt = weights[dataInd];
2063  }
2064  const float* scale_bias = reinterpret_cast<const float*>(
2065  &input[idx * fused_block_size + block_size]);
2066  bio = wgt * scale_bias[1];
2067  wgt = wgt * scale_bias[0];
2068  __m256 vbio = _mm256_set1_ps(bio);
2069  __m256 vwgt = _mm256_set1_ps(wgt);
2070  const uint8_t* ip = &input[idx * fused_block_size];
2071  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2072  ? (dataInd + prefdist_T0)
2073  : dataInd;
2074  const int32_t idx_pref_T0 = indices[next_T0];
2075  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2076  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2077  vop0 = _mm256_fmadd_ps(
2078  vwgt,
2079  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2080  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2081  _mm256_add_ps(vop0, vbio));
2082  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2083  vop8 = _mm256_fmadd_ps(
2084  vwgt,
2085  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2086  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2087  _mm256_add_ps(vop8, vbio));
2088  // skip unnecessary prefetch of (&ip_next_T0[8])
2089  vop16 = _mm256_fmadd_ps(
2090  vwgt,
2091  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2092  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2093  _mm256_add_ps(vop16, vbio));
2094  // skip unnecessary prefetch of (&ip_next_T0[16])
2095  vop24 = _mm256_fmadd_ps(
2096  vwgt,
2097  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2098  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2099  _mm256_add_ps(vop24, vbio));
2100  // skip unnecessary prefetch of (&ip_next_T0[24])
2101  }
2102  if (normalize_by_lengths == false) {
2103  _mm256_storeu_ps(&op[0], vop0);
2104  _mm256_storeu_ps(&op[8], vop8);
2105  _mm256_storeu_ps(&op[16], vop16);
2106  _mm256_storeu_ps(&op[24], vop24);
2107  } else if (lengths[rangeIndex]) {
2108  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2109  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2110  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2111  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2112  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2113  }
2114  }
2115  } else if (block_size == 16) {
2116  // unrolling 2 times
2117  int32_t dataInd = 0;
2118  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2119  float* op = &out[rangeIndex * block_size];
2120  __m256 vop0 = _mm256_setzero_ps();
2121  __m256 vop8 = _mm256_setzero_ps();
2122  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2123  ++dataInd) {
2124  const int32_t idx = indices[dataInd];
2125  CAFFE_ENFORCE(
2126  idx >= 0 && idx < data_size,
2127  "Index ",
2128  dataInd,
2129  " is out of bounds: ",
2130  idx,
2131  ", range 0 to ",
2132  data_size);
2133  float wgt = 1.f;
2134  float bio;
2135  if (weights) {
2136  wgt = weights[dataInd];
2137  }
2138  const float* scale_bias = reinterpret_cast<const float*>(
2139  &input[idx * fused_block_size + block_size]);
2140  bio = wgt * scale_bias[1];
2141  wgt = wgt * scale_bias[0];
2142  __m256 vbio = _mm256_set1_ps(bio);
2143  __m256 vwgt = _mm256_set1_ps(wgt);
2144  const uint8_t* ip = &input[idx * fused_block_size];
2145  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2146  ? (dataInd + prefdist_T0)
2147  : dataInd;
2148  const int32_t idx_pref_T0 = indices[next_T0];
2149  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2150  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2151  vop0 = _mm256_fmadd_ps(
2152  vwgt,
2153  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2154  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2155  _mm256_add_ps(vop0, vbio));
2156  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2157  vop8 = _mm256_fmadd_ps(
2158  vwgt,
2159  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2160  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2161  _mm256_add_ps(vop8, vbio));
2162  // skip unnecessary prefetch of (&ip_next_T0[8])
2163  }
2164  if (normalize_by_lengths == false) {
2165  _mm256_storeu_ps(&op[0], vop0);
2166  _mm256_storeu_ps(&op[8], vop8);
2167  } else if (lengths[rangeIndex]) {
2168  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2169  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2170  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2171  }
2172  }
2173  } else {
2174  // generic code
2175  int32_t dataInd = 0;
2176  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2177  float* op = &out[rangeIndex * block_size];
2178  TIndex j = 0;
2179  for (; j + 8 <= block_size; j += 8) {
2180  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2181  }
2182  for (; j < block_size; j++) {
2183  op[j] = 0.0f;
2184  }
2185  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2186  ++dataInd) {
2187  const int32_t idx = indices[dataInd];
2188  CAFFE_ENFORCE(
2189  idx >= 0 && idx < data_size,
2190  "Index ",
2191  dataInd,
2192  " is out of bounds: ",
2193  idx,
2194  ", range 0 to ",
2195  data_size);
2196  float wgt = 1.f;
2197  float bio;
2198  if (weights) {
2199  wgt = weights[dataInd];
2200  }
2201  const float* scale_bias = reinterpret_cast<const float*>(
2202  &input[idx * fused_block_size + block_size]);
2203  bio = wgt * scale_bias[1];
2204  wgt = wgt * scale_bias[0];
2205  __m256 vbio = _mm256_set1_ps(bio);
2206  __m256 vwgt = _mm256_set1_ps(wgt);
2207  const uint8_t* ip = &input[idx * fused_block_size];
2208  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2209  ? (dataInd + prefdist_T0)
2210  : dataInd;
2211  const int32_t idx_pref_T0 = indices[next_T0];
2212  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2213  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2214  j = 0;
2215  for (; j + 8 <= block_size; j += 8) {
2216  _mm256_storeu_ps(
2217  &op[j],
2218  _mm256_fmadd_ps(
2219  vwgt,
2220  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2221  reinterpret_cast<const __m128i*>(&ip[j])))),
2222  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2223  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2224  }
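 // Scalar tail: dequantize and accumulate the remaining elements one byte at
 // a time.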
2225  for (; j < block_size; j++) {
2226  op[j] += wgt * ((float)ip[j]) + bio;
2227  }
2228  }
2229  if (normalize_by_lengths && lengths[rangeIndex]) {
2230  float len_inv = 1.0f / lengths[rangeIndex];
2231  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2232  j = 0;
2233  for (; j + 8 <= block_size; j += 8) {
2234  _mm256_storeu_ps(
2235  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2236  }
2237  for (; j < block_size; j++) {
2238  op[j] = len_inv * op[j];
2239  }
2240  }
2241  }
2242  }
2243 }
2244 
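// Same fused 8-bit rowwise kernel as above, specialized for 64-bit indices.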
2245 void Fused8BitRowwiseEmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
2246  const TIndex block_size,
2247  const TIndex output_size,
2248  const TIndex index_size,
2249  const TIndex data_size,
2250  const uint8_t* input,
2251  const int64_t* indices,
2252  const int* lengths,
2253  const float* weights,
2254  bool normalize_by_lengths,
2255  float* out) {
2256  const int64_t prefdist_T0 = 16;
2257  const int64_t fused_block_size = block_size + 8;
2258  if (block_size == 128) {
2259  // unrolling 16 times
2260  int64_t dataInd = 0;
2261  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2262  float* op = &out[rangeIndex * block_size];
2263  __m256 vop0 = _mm256_setzero_ps();
2264  __m256 vop8 = _mm256_setzero_ps();
2265  __m256 vop16 = _mm256_setzero_ps();
2266  __m256 vop24 = _mm256_setzero_ps();
2267  __m256 vop32 = _mm256_setzero_ps();
2268  __m256 vop40 = _mm256_setzero_ps();
2269  __m256 vop48 = _mm256_setzero_ps();
2270  __m256 vop56 = _mm256_setzero_ps();
2271  __m256 vop64 = _mm256_setzero_ps();
2272  __m256 vop72 = _mm256_setzero_ps();
2273  __m256 vop80 = _mm256_setzero_ps();
2274  __m256 vop88 = _mm256_setzero_ps();
2275  __m256 vop96 = _mm256_setzero_ps();
2276  __m256 vop104 = _mm256_setzero_ps();
2277  __m256 vop112 = _mm256_setzero_ps();
2278  __m256 vop120 = _mm256_setzero_ps();
2279  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2280  ++dataInd) {
2281  const int64_t idx = indices[dataInd];
2282  CAFFE_ENFORCE(
2283  idx >= 0 && idx < data_size,
2284  "Index ",
2285  dataInd,
2286  " is out of bounds: ",
2287  idx,
2288  ", range 0 to ",
2289  data_size);
2290  float wgt = 1.f;
2291  float bio;
2292  if (weights) {
2293  wgt = weights[dataInd];
2294  }
2295  const float* scale_bias = reinterpret_cast<const float*>(
2296  &input[idx * fused_block_size + block_size]);
2297  bio = wgt * scale_bias[1];
2298  wgt = wgt * scale_bias[0];
2299  __m256 vbio = _mm256_set1_ps(bio);
2300  __m256 vwgt = _mm256_set1_ps(wgt);
2301  const uint8_t* ip = &input[idx * fused_block_size];
2302  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2303  ? (dataInd + prefdist_T0)
2304  : dataInd;
2305  const int64_t idx_pref_T0 = indices[next_T0];
2306  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2307  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2308  vop0 = _mm256_fmadd_ps(
2309  vwgt,
2310  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2311  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2312  _mm256_add_ps(vop0, vbio));
2313  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2314  vop8 = _mm256_fmadd_ps(
2315  vwgt,
2316  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2317  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2318  _mm256_add_ps(vop8, vbio));
2319  // skip unnecessary prefetch of (&ip_next_T0[8])
2320  vop16 = _mm256_fmadd_ps(
2321  vwgt,
2322  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2323  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2324  _mm256_add_ps(vop16, vbio));
2325  // skip unnecessary prefetch of (&ip_next_T0[16])
2326  vop24 = _mm256_fmadd_ps(
2327  vwgt,
2328  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2329  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2330  _mm256_add_ps(vop24, vbio));
2331  // skip unnecessary prefetch of (&ip_next_T0[24])
2332  vop32 = _mm256_fmadd_ps(
2333  vwgt,
2334  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2335  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2336  _mm256_add_ps(vop32, vbio));
2337  // skip unnecessary prefetch of (&ip_next_T0[32])
2338  vop40 = _mm256_fmadd_ps(
2339  vwgt,
2340  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2341  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2342  _mm256_add_ps(vop40, vbio));
2343  // skip unnecessary prefetch of (&ip_next_T0[40])
2344  vop48 = _mm256_fmadd_ps(
2345  vwgt,
2346  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2347  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2348  _mm256_add_ps(vop48, vbio));
2349  // skip unnecessary prefetch of (&ip_next_T0[48])
2350  vop56 = _mm256_fmadd_ps(
2351  vwgt,
2352  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2353  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2354  _mm256_add_ps(vop56, vbio));
2355  // skip unnecessary prefetch of (&ip_next_T0[56])
2356  vop64 = _mm256_fmadd_ps(
2357  vwgt,
2358  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2359  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2360  _mm256_add_ps(vop64, vbio));
2361  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
2362  vop72 = _mm256_fmadd_ps(
2363  vwgt,
2364  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2365  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2366  _mm256_add_ps(vop72, vbio));
2367  // skip unnecessary prefetch of (&ip_next_T0[72])
2368  vop80 = _mm256_fmadd_ps(
2369  vwgt,
2370  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2371  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2372  _mm256_add_ps(vop80, vbio));
2373  // skip unnecessary prefetch of (&ip_next_T0[80])
2374  vop88 = _mm256_fmadd_ps(
2375  vwgt,
2376  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2377  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2378  _mm256_add_ps(vop88, vbio));
2379  // skip unnecessary prefetch of (&ip_next_T0[88])
2380  vop96 = _mm256_fmadd_ps(
2381  vwgt,
2382  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2383  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2384  _mm256_add_ps(vop96, vbio));
2385  // skip unnecessary prefetch of (&ip_next_T0[96])
2386  vop104 = _mm256_fmadd_ps(
2387  vwgt,
2388  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2389  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2390  _mm256_add_ps(vop104, vbio));
2391  // skip unnecessary prefetch of (&ip_next_T0[104])
2392  vop112 = _mm256_fmadd_ps(
2393  vwgt,
2394  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2395  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2396  _mm256_add_ps(vop112, vbio));
2397  // skip unnecessary prefetch of (&ip_next_T0[112])
2398  vop120 = _mm256_fmadd_ps(
2399  vwgt,
2400  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2401  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2402  _mm256_add_ps(vop120, vbio));
2403  // skip unnecessary prefetch of (&ip_next_T0[120])
2404  }
2405  if (normalize_by_lengths == false) {
2406  _mm256_storeu_ps(&op[0], vop0);
2407  _mm256_storeu_ps(&op[8], vop8);
2408  _mm256_storeu_ps(&op[16], vop16);
2409  _mm256_storeu_ps(&op[24], vop24);
2410  _mm256_storeu_ps(&op[32], vop32);
2411  _mm256_storeu_ps(&op[40], vop40);
2412  _mm256_storeu_ps(&op[48], vop48);
2413  _mm256_storeu_ps(&op[56], vop56);
2414  _mm256_storeu_ps(&op[64], vop64);
2415  _mm256_storeu_ps(&op[72], vop72);
2416  _mm256_storeu_ps(&op[80], vop80);
2417  _mm256_storeu_ps(&op[88], vop88);
2418  _mm256_storeu_ps(&op[96], vop96);
2419  _mm256_storeu_ps(&op[104], vop104);
2420  _mm256_storeu_ps(&op[112], vop112);
2421  _mm256_storeu_ps(&op[120], vop120);
2422  } else if (lengths[rangeIndex]) {
2423  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2424  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2425  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2426  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2427  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2428  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2429  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2430  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2431  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2432  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2433  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2434  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2435  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2436  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2437  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2438  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2439  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2440  }
2441  }
2442  } else if (block_size == 64) {
2443  // unrolling 8 times
2444  int64_t dataInd = 0;
2445  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2446  float* op = &out[rangeIndex * block_size];
2447  __m256 vop0 = _mm256_setzero_ps();
2448  __m256 vop8 = _mm256_setzero_ps();
2449  __m256 vop16 = _mm256_setzero_ps();
2450  __m256 vop24 = _mm256_setzero_ps();
2451  __m256 vop32 = _mm256_setzero_ps();
2452  __m256 vop40 = _mm256_setzero_ps();
2453  __m256 vop48 = _mm256_setzero_ps();
2454  __m256 vop56 = _mm256_setzero_ps();
2455  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2456  ++dataInd) {
2457  const int64_t idx = indices[dataInd];
2458  CAFFE_ENFORCE(
2459  idx >= 0 && idx < data_size,
2460  "Index ",
2461  dataInd,
2462  " is out of bounds: ",
2463  idx,
2464  ", range 0 to ",
2465  data_size);
2466  float wgt = 1.f;
2467  float bio;
2468  if (weights) {
2469  wgt = weights[dataInd];
2470  }
2471  const float* scale_bias = reinterpret_cast<const float*>(
2472  &input[idx * fused_block_size + block_size]);
2473  bio = wgt * scale_bias[1];
2474  wgt = wgt * scale_bias[0];
2475  __m256 vbio = _mm256_set1_ps(bio);
2476  __m256 vwgt = _mm256_set1_ps(wgt);
2477  const uint8_t* ip = &input[idx * fused_block_size];
2478  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2479  ? (dataInd + prefdist_T0)
2480  : dataInd;
2481  const int64_t idx_pref_T0 = indices[next_T0];
2482  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2483  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2484  vop0 = _mm256_fmadd_ps(
2485  vwgt,
2486  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2487  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2488  _mm256_add_ps(vop0, vbio));
2489  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2490  vop8 = _mm256_fmadd_ps(
2491  vwgt,
2492  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2493  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2494  _mm256_add_ps(vop8, vbio));
2495  // skip unnecessary prefetch of (&ip_next_T0[8])
2496  vop16 = _mm256_fmadd_ps(
2497  vwgt,
2498  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2499  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2500  _mm256_add_ps(vop16, vbio));
2501  // skip unnecessary prefetch of (&ip_next_T0[16])
2502  vop24 = _mm256_fmadd_ps(
2503  vwgt,
2504  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2505  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2506  _mm256_add_ps(vop24, vbio));
2507  // skip unnecessary prefetch of (&ip_next_T0[24])
2508  vop32 = _mm256_fmadd_ps(
2509  vwgt,
2510  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2511  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2512  _mm256_add_ps(vop32, vbio));
2513  // skip unnecessary prefetch of (&ip_next_T0[32])
2514  vop40 = _mm256_fmadd_ps(
2515  vwgt,
2516  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2517  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2518  _mm256_add_ps(vop40, vbio));
2519  // skip unnecessary prefetch of (&ip_next_T0[40])
2520  vop48 = _mm256_fmadd_ps(
2521  vwgt,
2522  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2523  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2524  _mm256_add_ps(vop48, vbio));
2525  // skip unnecessary prefetch of (&ip_next_T0[48])
2526  vop56 = _mm256_fmadd_ps(
2527  vwgt,
2528  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2529  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2530  _mm256_add_ps(vop56, vbio));
2531  // skip unnecessary prefetch of (&ip_next_T0[56])
2532  }
2533  if (normalize_by_lengths == false) {
2534  _mm256_storeu_ps(&op[0], vop0);
2535  _mm256_storeu_ps(&op[8], vop8);
2536  _mm256_storeu_ps(&op[16], vop16);
2537  _mm256_storeu_ps(&op[24], vop24);
2538  _mm256_storeu_ps(&op[32], vop32);
2539  _mm256_storeu_ps(&op[40], vop40);
2540  _mm256_storeu_ps(&op[48], vop48);
2541  _mm256_storeu_ps(&op[56], vop56);
2542  } else if (lengths[rangeIndex]) {
2543  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2544  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2545  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2546  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2547  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2548  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2549  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2550  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2551  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2552  }
2553  }
2554  } else if (block_size == 32) {
2555  // unrolling 4 times
2556  int64_t dataInd = 0;
2557  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2558  float* op = &out[rangeIndex * block_size];
2559  __m256 vop0 = _mm256_setzero_ps();
2560  __m256 vop8 = _mm256_setzero_ps();
2561  __m256 vop16 = _mm256_setzero_ps();
2562  __m256 vop24 = _mm256_setzero_ps();
2563  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2564  ++dataInd) {
2565  const int64_t idx = indices[dataInd];
2566  CAFFE_ENFORCE(
2567  idx >= 0 && idx < data_size,
2568  "Index ",
2569  dataInd,
2570  " is out of bounds: ",
2571  idx,
2572  ", range 0 to ",
2573  data_size);
2574  float wgt = 1.f;
2575  float bio;
2576  if (weights) {
2577  wgt = weights[dataInd];
2578  }
2579  const float* scale_bias = reinterpret_cast<const float*>(
2580  &input[idx * fused_block_size + block_size]);
2581  bio = wgt * scale_bias[1];
2582  wgt = wgt * scale_bias[0];
2583  __m256 vbio = _mm256_set1_ps(bio);
2584  __m256 vwgt = _mm256_set1_ps(wgt);
2585  const uint8_t* ip = &input[idx * fused_block_size];
2586  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2587  ? (dataInd + prefdist_T0)
2588  : dataInd;
2589  const int64_t idx_pref_T0 = indices[next_T0];
2590  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2591  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2592  vop0 = _mm256_fmadd_ps(
2593  vwgt,
2594  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2595  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2596  _mm256_add_ps(vop0, vbio));
2597  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2598  vop8 = _mm256_fmadd_ps(
2599  vwgt,
2600  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2601  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2602  _mm256_add_ps(vop8, vbio));
2603  // skip unnecessary prefetch of (&ip_next_T0[8])
2604  vop16 = _mm256_fmadd_ps(
2605  vwgt,
2606  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2607  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2608  _mm256_add_ps(vop16, vbio));
2609  // skip unnecessary prefetch of (&ip_next_T0[16])
2610  vop24 = _mm256_fmadd_ps(
2611  vwgt,
2612  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2613  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2614  _mm256_add_ps(vop24, vbio));
2615  // skip unnecessary prefetch of (&ip_next_T0[24])
2616  }
2617  if (normalize_by_lengths == false) {
2618  _mm256_storeu_ps(&op[0], vop0);
2619  _mm256_storeu_ps(&op[8], vop8);
2620  _mm256_storeu_ps(&op[16], vop16);
2621  _mm256_storeu_ps(&op[24], vop24);
2622  } else if (lengths[rangeIndex]) {
2623  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2624  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2625  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2626  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2627  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2628  }
2629  }
2630  } else if (block_size == 16) {
2631  // unrolling 2 times
2632  int64_t dataInd = 0;
2633  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2634  float* op = &out[rangeIndex * block_size];
2635  __m256 vop0 = _mm256_setzero_ps();
2636  __m256 vop8 = _mm256_setzero_ps();
2637  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2638  ++dataInd) {
2639  const int64_t idx = indices[dataInd];
2640  CAFFE_ENFORCE(
2641  idx >= 0 && idx < data_size,
2642  "Index ",
2643  dataInd,
2644  " is out of bounds: ",
2645  idx,
2646  ", range 0 to ",
2647  data_size);
2648  float wgt = 1.f;
2649  float bio;
2650  if (weights) {
2651  wgt = weights[dataInd];
2652  }
2653  const float* scale_bias = reinterpret_cast<const float*>(
2654  &input[idx * fused_block_size + block_size]);
2655  bio = wgt * scale_bias[1];
2656  wgt = wgt * scale_bias[0];
2657  __m256 vbio = _mm256_set1_ps(bio);
2658  __m256 vwgt = _mm256_set1_ps(wgt);
2659  const uint8_t* ip = &input[idx * fused_block_size];
2660  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2661  ? (dataInd + prefdist_T0)
2662  : dataInd;
2663  const int64_t idx_pref_T0 = indices[next_T0];
2664  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2665  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2666  vop0 = _mm256_fmadd_ps(
2667  vwgt,
2668  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2669  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2670  _mm256_add_ps(vop0, vbio));
2671  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2672  vop8 = _mm256_fmadd_ps(
2673  vwgt,
2674  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2675  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2676  _mm256_add_ps(vop8, vbio));
2677  // skip unnecessary prefetch of (&ip_next_T0[8])
2678  }
2679  if (normalize_by_lengths == false) {
2680  _mm256_storeu_ps(&op[0], vop0);
2681  _mm256_storeu_ps(&op[8], vop8);
2682  } else if (lengths[rangeIndex]) {
2683  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2684  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2685  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2686  }
2687  }
2688  } else {
2689  // generic code
2690  int64_t dataInd = 0;
2691  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2692  float* op = &out[rangeIndex * block_size];
2693  TIndex j = 0;
2694  for (; j + 8 <= block_size; j += 8) {
2695  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2696  }
2697  for (; j < block_size; j++) {
2698  op[j] = 0.0f;
2699  }
2700  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2701  ++dataInd) {
2702  const int64_t idx = indices[dataInd];
2703  CAFFE_ENFORCE(
2704  idx >= 0 && idx < data_size,
2705  "Index ",
2706  dataInd,
2707  " is out of bounds: ",
2708  idx,
2709  ", range 0 to ",
2710  data_size);
2711  float wgt = 1.f;
2712  float bio;
2713  if (weights) {
2714  wgt = weights[dataInd];
2715  }
2716  const float* scale_bias = reinterpret_cast<const float*>(
2717  &input[idx * fused_block_size + block_size]);
2718  bio = wgt * scale_bias[1];
2719  wgt = wgt * scale_bias[0];
2720  __m256 vbio = _mm256_set1_ps(bio);
2721  __m256 vwgt = _mm256_set1_ps(wgt);
2722  const uint8_t* ip = &input[idx * fused_block_size];
2723  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2724  ? (dataInd + prefdist_T0)
2725  : dataInd;
2726  const int64_t idx_pref_T0 = indices[next_T0];
2727  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2728  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2729  j = 0;
2730  for (; j + 8 <= block_size; j += 8) {
2731  _mm256_storeu_ps(
2732  &op[j],
2733  _mm256_fmadd_ps(
2734  vwgt,
2735  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2736  reinterpret_cast<const __m128i*>(&ip[j])))),
2737  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2738  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2739  }
2740  for (; j < block_size; j++) {
2741  op[j] += wgt * ((float)ip[j]) + bio;
2742  }
2743  }
2744  if (normalize_by_lengths && lengths[rangeIndex]) {
2745  float len_inv = 1.0f / lengths[rangeIndex];
2746  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2747  j = 0;
2748  for (; j + 8 <= block_size; j += 8) {
2749  _mm256_storeu_ps(
2750  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2751  }
2752  for (; j < block_size; j++) {
2753  op[j] = len_inv * op[j];
2754  }
2755  }
2756  }
2757  }
2758 }
2759 
2760 } // namespace caffe2