Caffe2 - C++ API
A deep learning, cross-platform ML framework
embedding_lookup_avx2.cc
24 #include <immintrin.h>
25 #include "caffe2/core/common.h"
26 #include "caffe2/core/types.h"
27 
28 namespace caffe2 {
29 
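// AVX2+FMA kernels for pooled embedding lookup. For each output segment r,
// the kernel sums lengths[r] consecutive rows of the embedding table selected
// by `indices`, optionally scaling each row by the matching entry of
// `weights`, and writes a block_size-wide result to out[r * block_size].
// If normalize_by_lengths is set, the sum is divided by the segment length
// (mean pooling). Common block sizes (128, 64, 32, 16) get hand-unrolled
// paths; any other size falls through to the generic loop at the end of each
// function.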
30 void EmbeddingLookup_int32_t_float_float__avx2_fma(
31  const TIndex block_size,
32  const TIndex output_size,
33  const TIndex index_size,
34  const TIndex data_size,
35  const float* input,
36  const int32_t* indices,
37  const int* lengths,
38  const float* weights,
39  const float* scale_bias,
40  bool normalize_by_lengths,
41  float* out) {
42  const int32_t prefdist_T0 = 16;
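// Prefetch distance: while accumulating the row for indices[dataInd], issue a
// T0 prefetch for the row of indices[dataInd + prefdist_T0] so that row is
// already in cache when its iteration is reached.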
43  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
44  if (block_size == 128) {
45  // unrolling 16 times
46  int32_t dataInd = 0;
47  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
48  float* op = &out[rangeIndex * block_size];
49  __m256 vop0 = _mm256_setzero_ps();
50  __m256 vop8 = _mm256_setzero_ps();
51  __m256 vop16 = _mm256_setzero_ps();
52  __m256 vop24 = _mm256_setzero_ps();
53  __m256 vop32 = _mm256_setzero_ps();
54  __m256 vop40 = _mm256_setzero_ps();
55  __m256 vop48 = _mm256_setzero_ps();
56  __m256 vop56 = _mm256_setzero_ps();
57  __m256 vop64 = _mm256_setzero_ps();
58  __m256 vop72 = _mm256_setzero_ps();
59  __m256 vop80 = _mm256_setzero_ps();
60  __m256 vop88 = _mm256_setzero_ps();
61  __m256 vop96 = _mm256_setzero_ps();
62  __m256 vop104 = _mm256_setzero_ps();
63  __m256 vop112 = _mm256_setzero_ps();
64  __m256 vop120 = _mm256_setzero_ps();
65  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
66  ++dataInd) {
67  const int32_t idx = indices[dataInd];
68  CAFFE_ENFORCE(
69  idx >= 0 && idx < data_size,
70  "Index ",
71  dataInd,
72  " is out of bounds: ",
73  idx,
74  ", range 0 to ",
75  data_size);
76  float wgt = 1.f;
77  if (weights) {
78  wgt = weights[dataInd];
79  }
80  __m256 vwgt = _mm256_set1_ps(wgt);
81  const float* ip = &input[idx * block_size];
82  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
83  ? (dataInd + prefdist_T0)
84  : dataInd;
85  const int32_t idx_pref_T0 = indices[next_T0];
86  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
87  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
88  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
89  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
90  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
91  // skip unnecessary prefetch of (&ip_next_T0[8])
92  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
93  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
94  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
95  // skip unnecessary prefetch of (&ip_next_T0[24])
96  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
97  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
98  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
99  // skip unnecessary prefetch of (&ip_next_T0[40])
100  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
101  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
102  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
103  // skip unnecessary prefetch of (&ip_next_T0[56])
104  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
105  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
106  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
107  // skip unnecessary prefetch of (&ip_next_T0[72])
108  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
109  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
110  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
111  // skip unnecessary prefetch of (&ip_next_T0[88])
112  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
113  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
114  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
115  // skip unnecessary prefetch of (&ip_next_T0[104])
116  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
117  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
118  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
119  // skip unnecessary prefetch of (&ip_next_T0[120])
120  }
121  if (normalize_by_lengths == false) {
122  _mm256_storeu_ps(&op[0], vop0);
123  _mm256_storeu_ps(&op[8], vop8);
124  _mm256_storeu_ps(&op[16], vop16);
125  _mm256_storeu_ps(&op[24], vop24);
126  _mm256_storeu_ps(&op[32], vop32);
127  _mm256_storeu_ps(&op[40], vop40);
128  _mm256_storeu_ps(&op[48], vop48);
129  _mm256_storeu_ps(&op[56], vop56);
130  _mm256_storeu_ps(&op[64], vop64);
131  _mm256_storeu_ps(&op[72], vop72);
132  _mm256_storeu_ps(&op[80], vop80);
133  _mm256_storeu_ps(&op[88], vop88);
134  _mm256_storeu_ps(&op[96], vop96);
135  _mm256_storeu_ps(&op[104], vop104);
136  _mm256_storeu_ps(&op[112], vop112);
137  _mm256_storeu_ps(&op[120], vop120);
138  } else if (lengths[rangeIndex]) {
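// Mean pooling: scale the accumulated sums by 1 / lengths[rangeIndex].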
139  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
140  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
141  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
142  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
143  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
144  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
145  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
146  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
147  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
148  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
149  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
150  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
151  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
152  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
153  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
154  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
155  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
156  }
157  }
158  } else if (block_size == 64) {
159  // unrolling 8 times
160  int32_t dataInd = 0;
161  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
162  float* op = &out[rangeIndex * block_size];
163  __m256 vop0 = _mm256_setzero_ps();
164  __m256 vop8 = _mm256_setzero_ps();
165  __m256 vop16 = _mm256_setzero_ps();
166  __m256 vop24 = _mm256_setzero_ps();
167  __m256 vop32 = _mm256_setzero_ps();
168  __m256 vop40 = _mm256_setzero_ps();
169  __m256 vop48 = _mm256_setzero_ps();
170  __m256 vop56 = _mm256_setzero_ps();
171  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
172  ++dataInd) {
173  const int32_t idx = indices[dataInd];
174  CAFFE_ENFORCE(
175  idx >= 0 && idx < data_size,
176  "Index ",
177  dataInd,
178  " is out of bounds: ",
179  idx,
180  ", range 0 to ",
181  data_size);
182  float wgt = 1.f;
183  if (weights) {
184  wgt = weights[dataInd];
185  }
186  __m256 vwgt = _mm256_set1_ps(wgt);
187  const float* ip = &input[idx * block_size];
188  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
189  ? (dataInd + prefdist_T0)
190  : dataInd;
191  const int32_t idx_pref_T0 = indices[next_T0];
192  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
193  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
194  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
195  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
196  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
197  // skip unnecessary prefetch of (&ip_next_T0[8])
198  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
199  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
200  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
201  // skip unnecessary prefetch of (&ip_next_T0[24])
202  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
203  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
204  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
205  // skip unnecessary prefetch of (&ip_next_T0[40])
206  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
207  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
208  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
209  // skip unnecessary prefetch of (&ip_next_T0[56])
210  }
211  if (normalize_by_lengths == false) {
212  _mm256_storeu_ps(&op[0], vop0);
213  _mm256_storeu_ps(&op[8], vop8);
214  _mm256_storeu_ps(&op[16], vop16);
215  _mm256_storeu_ps(&op[24], vop24);
216  _mm256_storeu_ps(&op[32], vop32);
217  _mm256_storeu_ps(&op[40], vop40);
218  _mm256_storeu_ps(&op[48], vop48);
219  _mm256_storeu_ps(&op[56], vop56);
220  } else if (lengths[rangeIndex]) {
221  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
222  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
223  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
224  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
225  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
226  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
227  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
228  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
229  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
230  }
231  }
232  } else if (block_size == 32) {
233  // unrolling 4 times
234  int32_t dataInd = 0;
235  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
236  float* op = &out[rangeIndex * block_size];
237  __m256 vop0 = _mm256_setzero_ps();
238  __m256 vop8 = _mm256_setzero_ps();
239  __m256 vop16 = _mm256_setzero_ps();
240  __m256 vop24 = _mm256_setzero_ps();
241  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
242  ++dataInd) {
243  const int32_t idx = indices[dataInd];
244  CAFFE_ENFORCE(
245  idx >= 0 && idx < data_size,
246  "Index ",
247  dataInd,
248  " is out of bounds: ",
249  idx,
250  ", range 0 to ",
251  data_size);
252  float wgt = 1.f;
253  if (weights) {
254  wgt = weights[dataInd];
255  }
256  __m256 vwgt = _mm256_set1_ps(wgt);
257  const float* ip = &input[idx * block_size];
258  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
259  ? (dataInd + prefdist_T0)
260  : dataInd;
261  const int32_t idx_pref_T0 = indices[next_T0];
262  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
263  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
264  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
265  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
266  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
267  // skip unnecessary prefetch of (&ip_next_T0[8])
268  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
269  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
270  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
271  // skip unnecessary prefetch of (&ip_next_T0[24])
272  }
273  if (normalize_by_lengths == false) {
274  _mm256_storeu_ps(&op[0], vop0);
275  _mm256_storeu_ps(&op[8], vop8);
276  _mm256_storeu_ps(&op[16], vop16);
277  _mm256_storeu_ps(&op[24], vop24);
278  } else if (lengths[rangeIndex]) {
279  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
280  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
281  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
282  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
283  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
284  }
285  }
286  } else if (block_size == 16) {
287  // unrolling 2 times
288  int32_t dataInd = 0;
289  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
290  float* op = &out[rangeIndex * block_size];
291  __m256 vop0 = _mm256_setzero_ps();
292  __m256 vop8 = _mm256_setzero_ps();
293  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
294  ++dataInd) {
295  const int32_t idx = indices[dataInd];
296  CAFFE_ENFORCE(
297  idx >= 0 && idx < data_size,
298  "Index ",
299  dataInd,
300  " is out of bounds: ",
301  idx,
302  ", range 0 to ",
303  data_size);
304  float wgt = 1.f;
305  if (weights) {
306  wgt = weights[dataInd];
307  }
308  __m256 vwgt = _mm256_set1_ps(wgt);
309  const float* ip = &input[idx * block_size];
310  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
311  ? (dataInd + prefdist_T0)
312  : dataInd;
313  const int32_t idx_pref_T0 = indices[next_T0];
314  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
315  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
316  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
317  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
318  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
319  // skip unnecessary prefetch of (&ip_next_T0[8])
320  }
321  if (normalize_by_lengths == false) {
322  _mm256_storeu_ps(&op[0], vop0);
323  _mm256_storeu_ps(&op[8], vop8);
324  } else if (lengths[rangeIndex]) {
325  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
326  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
327  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
328  }
329  }
330  } else {
331  // generic code
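// Generic path for arbitrary block_size: zero the output row (eight floats at
// a time, then a scalar tail) and accumulate the weighted rows directly into
// the output buffer.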
332  int32_t dataInd = 0;
333  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
334  float* op = &out[rangeIndex * block_size];
335  TIndex j = 0;
336  for (; j + 8 <= block_size; j += 8) {
337  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
338  }
339  for (; j < block_size; j++) {
340  op[j] = 0.0f;
341  }
342  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
343  ++dataInd) {
344  const int32_t idx = indices[dataInd];
345  CAFFE_ENFORCE(
346  idx >= 0 && idx < data_size,
347  "Index ",
348  dataInd,
349  " is out of bounds: ",
350  idx,
351  ", range 0 to ",
352  data_size);
353  float wgt = 1.f;
354  if (weights) {
355  wgt = weights[dataInd];
356  }
357  __m256 vwgt = _mm256_set1_ps(wgt);
358  const float* ip = &input[idx * block_size];
359  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
360  ? (dataInd + prefdist_T0)
361  : dataInd;
362  const int32_t idx_pref_T0 = indices[next_T0];
363  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
364  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
365  j = 0;
366  for (; j + 8 <= block_size; j += 8) {
367  _mm256_storeu_ps(
368  &op[j],
369  _mm256_fmadd_ps(
370  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
371  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
372  }
373  for (; j < block_size; j++) {
374  op[j] += wgt * ip[j];
375  }
376  }
377  if (normalize_by_lengths && lengths[rangeIndex]) {
378  float len_inv = 1.0f / lengths[rangeIndex];
379  __m256 vlen_inv = _mm256_set1_ps(len_inv);
380  j = 0;
381  for (; j + 8 <= block_size; j += 8) {
382  _mm256_storeu_ps(
383  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
384  }
385  for (; j < block_size; j++) {
386  op[j] = len_inv * op[j];
387  }
388  }
389  }
390  }
391 }
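// Usage sketch (illustrative only; sizes and data below are made up):
//
//   const TIndex block_size = 16, output_size = 2, index_size = 4, data_size = 8;
//   float table[8 * 16] = {};          // embedding table, data_size x block_size
//   int32_t indices[4] = {0, 2, 5, 7}; // rows to gather
//   int lengths[2] = {3, 1};           // segment sizes, summing to index_size
//   float out[2 * 16];
//   EmbeddingLookup_int32_t_float_float__avx2_fma(
//       block_size, output_size, index_size, data_size,
//       table, indices, lengths,
//       /*weights=*/nullptr, /*scale_bias=*/nullptr,
//       /*normalize_by_lengths=*/true, out);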
392 
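// Same kernel as above, specialized for 64-bit indices.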
393 void EmbeddingLookup_int64_t_float_float__avx2_fma(
394  const TIndex block_size,
395  const TIndex output_size,
396  const TIndex index_size,
397  const TIndex data_size,
398  const float* input,
399  const int64_t* indices,
400  const int* lengths,
401  const float* weights,
402  const float* scale_bias,
403  bool normalize_by_lengths,
404  float* out) {
405  const int64_t prefdist_T0 = 16;
406  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
407  if (block_size == 128) {
408  // unrolling 16 times
409  int64_t dataInd = 0;
410  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
411  float* op = &out[rangeIndex * block_size];
412  __m256 vop0 = _mm256_setzero_ps();
413  __m256 vop8 = _mm256_setzero_ps();
414  __m256 vop16 = _mm256_setzero_ps();
415  __m256 vop24 = _mm256_setzero_ps();
416  __m256 vop32 = _mm256_setzero_ps();
417  __m256 vop40 = _mm256_setzero_ps();
418  __m256 vop48 = _mm256_setzero_ps();
419  __m256 vop56 = _mm256_setzero_ps();
420  __m256 vop64 = _mm256_setzero_ps();
421  __m256 vop72 = _mm256_setzero_ps();
422  __m256 vop80 = _mm256_setzero_ps();
423  __m256 vop88 = _mm256_setzero_ps();
424  __m256 vop96 = _mm256_setzero_ps();
425  __m256 vop104 = _mm256_setzero_ps();
426  __m256 vop112 = _mm256_setzero_ps();
427  __m256 vop120 = _mm256_setzero_ps();
428  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
429  ++dataInd) {
430  const int64_t idx = indices[dataInd];
431  CAFFE_ENFORCE(
432  idx >= 0 && idx < data_size,
433  "Index ",
434  dataInd,
435  " is out of bounds: ",
436  idx,
437  ", range 0 to ",
438  data_size);
439  float wgt = 1.f;
440  if (weights) {
441  wgt = weights[dataInd];
442  }
443  __m256 vwgt = _mm256_set1_ps(wgt);
444  const float* ip = &input[idx * block_size];
445  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
446  ? (dataInd + prefdist_T0)
447  : dataInd;
448  const int64_t idx_pref_T0 = indices[next_T0];
449  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
450  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
451  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
452  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
453  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
454  // skip unnecessary prefetch of (&ip_next_T0[8])
455  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
456  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
457  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
458  // skip unnecessary prefetch of (&ip_next_T0[24])
459  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
460  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
461  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
462  // skip unnecessary prefetch of (&ip_next_T0[40])
463  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
464  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
465  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
466  // skip unnecessary prefetch of (&ip_next_T0[56])
467  vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
468  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
469  vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
470  // skip unnecessary prefetch of (&ip_next_T0[72])
471  vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
472  _mm_prefetch((&ip_next_T0[80]), _MM_HINT_T0);
473  vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
474  // skip unnecessary prefetch of (&ip_next_T0[88])
475  vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
476  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
477  vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
478  // skip unnecessary prefetch of (&ip_next_T0[104])
479  vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
480  _mm_prefetch((&ip_next_T0[112]), _MM_HINT_T0);
481  vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
482  // skip unnecessary prefetch of (&ip_next_T0[120])
483  }
484  if (normalize_by_lengths == false) {
485  _mm256_storeu_ps(&op[0], vop0);
486  _mm256_storeu_ps(&op[8], vop8);
487  _mm256_storeu_ps(&op[16], vop16);
488  _mm256_storeu_ps(&op[24], vop24);
489  _mm256_storeu_ps(&op[32], vop32);
490  _mm256_storeu_ps(&op[40], vop40);
491  _mm256_storeu_ps(&op[48], vop48);
492  _mm256_storeu_ps(&op[56], vop56);
493  _mm256_storeu_ps(&op[64], vop64);
494  _mm256_storeu_ps(&op[72], vop72);
495  _mm256_storeu_ps(&op[80], vop80);
496  _mm256_storeu_ps(&op[88], vop88);
497  _mm256_storeu_ps(&op[96], vop96);
498  _mm256_storeu_ps(&op[104], vop104);
499  _mm256_storeu_ps(&op[112], vop112);
500  _mm256_storeu_ps(&op[120], vop120);
501  } else if (lengths[rangeIndex]) {
502  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
503  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
504  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
505  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
506  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
507  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
508  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
509  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
510  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
511  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
512  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
513  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
514  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
515  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
516  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
517  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
518  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
519  }
520  }
521  } else if (block_size == 64) {
522  // unrolling 8 times
523  int64_t dataInd = 0;
524  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
525  float* op = &out[rangeIndex * block_size];
526  __m256 vop0 = _mm256_setzero_ps();
527  __m256 vop8 = _mm256_setzero_ps();
528  __m256 vop16 = _mm256_setzero_ps();
529  __m256 vop24 = _mm256_setzero_ps();
530  __m256 vop32 = _mm256_setzero_ps();
531  __m256 vop40 = _mm256_setzero_ps();
532  __m256 vop48 = _mm256_setzero_ps();
533  __m256 vop56 = _mm256_setzero_ps();
534  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
535  ++dataInd) {
536  const int64_t idx = indices[dataInd];
537  CAFFE_ENFORCE(
538  idx >= 0 && idx < data_size,
539  "Index ",
540  dataInd,
541  " is out of bounds: ",
542  idx,
543  ", range 0 to ",
544  data_size);
545  float wgt = 1.f;
546  if (weights) {
547  wgt = weights[dataInd];
548  }
549  __m256 vwgt = _mm256_set1_ps(wgt);
550  const float* ip = &input[idx * block_size];
551  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
552  ? (dataInd + prefdist_T0)
553  : dataInd;
554  const int64_t idx_pref_T0 = indices[next_T0];
555  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
556  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
557  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
558  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
559  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
560  // skip unnecessary prefetch of (&ip_next_T0[8])
561  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
562  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
563  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
564  // skip unnecessary prefetch of (&ip_next_T0[24])
565  vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
566  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
567  vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
568  // skip unnecessary prefetch of (&ip_next_T0[40])
569  vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
570  _mm_prefetch((&ip_next_T0[48]), _MM_HINT_T0);
571  vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
572  // skip unnecessary prefetch of (&ip_next_T0[56])
573  }
574  if (normalize_by_lengths == false) {
575  _mm256_storeu_ps(&op[0], vop0);
576  _mm256_storeu_ps(&op[8], vop8);
577  _mm256_storeu_ps(&op[16], vop16);
578  _mm256_storeu_ps(&op[24], vop24);
579  _mm256_storeu_ps(&op[32], vop32);
580  _mm256_storeu_ps(&op[40], vop40);
581  _mm256_storeu_ps(&op[48], vop48);
582  _mm256_storeu_ps(&op[56], vop56);
583  } else if (lengths[rangeIndex]) {
584  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
585  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
586  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
587  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
588  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
589  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
590  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
591  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
592  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
593  }
594  }
595  } else if (block_size == 32) {
596  // unrolling 4 times
597  int64_t dataInd = 0;
598  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
599  float* op = &out[rangeIndex * block_size];
600  __m256 vop0 = _mm256_setzero_ps();
601  __m256 vop8 = _mm256_setzero_ps();
602  __m256 vop16 = _mm256_setzero_ps();
603  __m256 vop24 = _mm256_setzero_ps();
604  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
605  ++dataInd) {
606  const int64_t idx = indices[dataInd];
607  CAFFE_ENFORCE(
608  idx >= 0 && idx < data_size,
609  "Index ",
610  dataInd,
611  " is out of bounds: ",
612  idx,
613  ", range 0 to ",
614  data_size);
615  float wgt = 1.f;
616  if (weights) {
617  wgt = weights[dataInd];
618  }
619  __m256 vwgt = _mm256_set1_ps(wgt);
620  const float* ip = &input[idx * block_size];
621  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
622  ? (dataInd + prefdist_T0)
623  : dataInd;
624  const int64_t idx_pref_T0 = indices[next_T0];
625  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
626  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
627  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
628  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
629  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
630  // skip unnecessary prefetch of (&ip_next_T0[8])
631  vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
632  _mm_prefetch((&ip_next_T0[16]), _MM_HINT_T0);
633  vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
634  // skip unnecessary prefetch of (&ip_next_T0[24])
635  }
636  if (normalize_by_lengths == false) {
637  _mm256_storeu_ps(&op[0], vop0);
638  _mm256_storeu_ps(&op[8], vop8);
639  _mm256_storeu_ps(&op[16], vop16);
640  _mm256_storeu_ps(&op[24], vop24);
641  } else if (lengths[rangeIndex]) {
642  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
643  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
644  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
645  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
646  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
647  }
648  }
649  } else if (block_size == 16) {
650  // unrolling 2 times
651  int64_t dataInd = 0;
652  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
653  float* op = &out[rangeIndex * block_size];
654  __m256 vop0 = _mm256_setzero_ps();
655  __m256 vop8 = _mm256_setzero_ps();
656  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
657  ++dataInd) {
658  const int64_t idx = indices[dataInd];
659  CAFFE_ENFORCE(
660  idx >= 0 && idx < data_size,
661  "Index ",
662  dataInd,
663  " is out of bounds: ",
664  idx,
665  ", range 0 to ",
666  data_size);
667  float wgt = 1.f;
668  if (weights) {
669  wgt = weights[dataInd];
670  }
671  __m256 vwgt = _mm256_set1_ps(wgt);
672  const float* ip = &input[idx * block_size];
673  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
674  ? (dataInd + prefdist_T0)
675  : dataInd;
676  const int64_t idx_pref_T0 = indices[next_T0];
677  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
678  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
679  vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
680  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
681  vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
682  // skip unnecessary prefetch of (&ip_next_T0[8])
683  }
684  if (normalize_by_lengths == false) {
685  _mm256_storeu_ps(&op[0], vop0);
686  _mm256_storeu_ps(&op[8], vop8);
687  } else if (lengths[rangeIndex]) {
688  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
689  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
690  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
691  }
692  }
693  } else {
694  // generic code
695  int64_t dataInd = 0;
696  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
697  float* op = &out[rangeIndex * block_size];
698  TIndex j = 0;
699  for (; j + 8 <= block_size; j += 8) {
700  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
701  }
702  for (; j < block_size; j++) {
703  op[j] = 0.0f;
704  }
705  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
706  ++dataInd) {
707  const int64_t idx = indices[dataInd];
708  CAFFE_ENFORCE(
709  idx >= 0 && idx < data_size,
710  "Index ",
711  dataInd,
712  " is out of bounds: ",
713  idx,
714  ", range 0 to ",
715  data_size);
716  float wgt = 1.f;
717  if (weights) {
718  wgt = weights[dataInd];
719  }
720  __m256 vwgt = _mm256_set1_ps(wgt);
721  const float* ip = &input[idx * block_size];
722  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
723  ? (dataInd + prefdist_T0)
724  : dataInd;
725  const int64_t idx_pref_T0 = indices[next_T0];
726  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
727  const float* ip_next_T0 = &input[idx_pref_T0 * block_size];
728  j = 0;
729  for (; j + 8 <= block_size; j += 8) {
730  _mm256_storeu_ps(
731  &op[j],
732  _mm256_fmadd_ps(
733  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
734  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
735  }
736  for (; j < block_size; j++) {
737  op[j] += wgt * ip[j];
738  }
739  }
740  if (normalize_by_lengths && lengths[rangeIndex]) {
741  float len_inv = 1.0f / lengths[rangeIndex];
742  __m256 vlen_inv = _mm256_set1_ps(len_inv);
743  j = 0;
744  for (; j + 8 <= block_size; j += 8) {
745  _mm256_storeu_ps(
746  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
747  }
748  for (; j < block_size; j++) {
749  op[j] = len_inv * op[j];
750  }
751  }
752  }
753  }
754 }
755 
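// Variant with fp16 (float16) embedding rows: each group of eight values is
// widened to fp32 with _mm256_cvtph_ps before the weighted accumulation; the
// accumulators and the output stay in fp32.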
756 void EmbeddingLookup_int32_t_float16_float__avx2_fma(
757  const TIndex block_size,
758  const TIndex output_size,
759  const TIndex index_size,
760  const TIndex data_size,
761  const float16* input,
762  const int32_t* indices,
763  const int* lengths,
764  const float* weights,
765  const float* scale_bias,
766  bool normalize_by_lengths,
767  float* out) {
768  const int32_t prefdist_T0 = 16;
769  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
770  if (block_size == 128) {
771  // unrolling 16 times
772  int32_t dataInd = 0;
773  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
774  float* op = &out[rangeIndex * block_size];
775  __m256 vop0 = _mm256_setzero_ps();
776  __m256 vop8 = _mm256_setzero_ps();
777  __m256 vop16 = _mm256_setzero_ps();
778  __m256 vop24 = _mm256_setzero_ps();
779  __m256 vop32 = _mm256_setzero_ps();
780  __m256 vop40 = _mm256_setzero_ps();
781  __m256 vop48 = _mm256_setzero_ps();
782  __m256 vop56 = _mm256_setzero_ps();
783  __m256 vop64 = _mm256_setzero_ps();
784  __m256 vop72 = _mm256_setzero_ps();
785  __m256 vop80 = _mm256_setzero_ps();
786  __m256 vop88 = _mm256_setzero_ps();
787  __m256 vop96 = _mm256_setzero_ps();
788  __m256 vop104 = _mm256_setzero_ps();
789  __m256 vop112 = _mm256_setzero_ps();
790  __m256 vop120 = _mm256_setzero_ps();
791  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
792  ++dataInd) {
793  const int32_t idx = indices[dataInd];
794  CAFFE_ENFORCE(
795  idx >= 0 && idx < data_size,
796  "Index ",
797  dataInd,
798  " is out of bounds: ",
799  idx,
800  ", range 0 to ",
801  data_size);
802  float wgt = 1.f;
803  if (weights) {
804  wgt = weights[dataInd];
805  }
806  __m256 vwgt = _mm256_set1_ps(wgt);
807  const float16* ip = &input[idx * block_size];
808  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
809  ? (dataInd + prefdist_T0)
810  : dataInd;
811  const int32_t idx_pref_T0 = indices[next_T0];
812  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
813  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
814  vop0 = _mm256_fmadd_ps(
815  vwgt,
816  _mm256_cvtph_ps(
817  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
818  vop0);
819  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
820  vop8 = _mm256_fmadd_ps(
821  vwgt,
822  _mm256_cvtph_ps(
823  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
824  vop8);
825  // skip unnecessary prefetch of (&ip_next_T0[8])
826  vop16 = _mm256_fmadd_ps(
827  vwgt,
828  _mm256_cvtph_ps(
829  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
830  vop16);
831  // skip unnecessary prefetch of (&ip_next_T0[16])
832  vop24 = _mm256_fmadd_ps(
833  vwgt,
834  _mm256_cvtph_ps(
835  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
836  vop24);
837  // skip unnecessary prefetch of (&ip_next_T0[24])
838  vop32 = _mm256_fmadd_ps(
839  vwgt,
840  _mm256_cvtph_ps(
841  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
842  vop32);
843  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
844  vop40 = _mm256_fmadd_ps(
845  vwgt,
846  _mm256_cvtph_ps(
847  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
848  vop40);
849  // skip unnecessary prefetch of (&ip_next_T0[40])
850  vop48 = _mm256_fmadd_ps(
851  vwgt,
852  _mm256_cvtph_ps(
853  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
854  vop48);
855  // skip unnecessary prefetch of (&ip_next_T0[48])
856  vop56 = _mm256_fmadd_ps(
857  vwgt,
858  _mm256_cvtph_ps(
859  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
860  vop56);
861  // skip unnecessary prefetch of (&ip_next_T0[56])
862  vop64 = _mm256_fmadd_ps(
863  vwgt,
864  _mm256_cvtph_ps(
865  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
866  vop64);
867  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
868  vop72 = _mm256_fmadd_ps(
869  vwgt,
870  _mm256_cvtph_ps(
871  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
872  vop72);
873  // skip unnecessary prefetch of (&ip_next_T0[72])
874  vop80 = _mm256_fmadd_ps(
875  vwgt,
876  _mm256_cvtph_ps(
877  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
878  vop80);
879  // skip unnecessary prefetch of (&ip_next_T0[80])
880  vop88 = _mm256_fmadd_ps(
881  vwgt,
882  _mm256_cvtph_ps(
883  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
884  vop88);
885  // skip unnecessary prefetch of (&ip_next_T0[88])
886  vop96 = _mm256_fmadd_ps(
887  vwgt,
888  _mm256_cvtph_ps(
889  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
890  vop96);
891  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
892  vop104 = _mm256_fmadd_ps(
893  vwgt,
894  _mm256_cvtph_ps(
895  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
896  vop104);
897  // skip unnecessary prefetch of (&ip_next_T0[104])
898  vop112 = _mm256_fmadd_ps(
899  vwgt,
900  _mm256_cvtph_ps(
901  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
902  vop112);
903  // skip unnecessary prefetch of (&ip_next_T0[112])
904  vop120 = _mm256_fmadd_ps(
905  vwgt,
906  _mm256_cvtph_ps(
907  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
908  vop120);
909  // skip unnecessary prefetch of (&ip_next_T0[120])
910  }
911  if (normalize_by_lengths == false) {
912  _mm256_storeu_ps(&op[0], vop0);
913  _mm256_storeu_ps(&op[8], vop8);
914  _mm256_storeu_ps(&op[16], vop16);
915  _mm256_storeu_ps(&op[24], vop24);
916  _mm256_storeu_ps(&op[32], vop32);
917  _mm256_storeu_ps(&op[40], vop40);
918  _mm256_storeu_ps(&op[48], vop48);
919  _mm256_storeu_ps(&op[56], vop56);
920  _mm256_storeu_ps(&op[64], vop64);
921  _mm256_storeu_ps(&op[72], vop72);
922  _mm256_storeu_ps(&op[80], vop80);
923  _mm256_storeu_ps(&op[88], vop88);
924  _mm256_storeu_ps(&op[96], vop96);
925  _mm256_storeu_ps(&op[104], vop104);
926  _mm256_storeu_ps(&op[112], vop112);
927  _mm256_storeu_ps(&op[120], vop120);
928  } else if (lengths[rangeIndex]) {
929  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
930  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
931  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
932  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
933  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
934  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
935  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
936  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
937  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
938  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
939  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
940  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
941  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
942  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
943  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
944  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
945  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
946  }
947  }
948  } else if (block_size == 64) {
949  // unrolling 8 times
950  int32_t dataInd = 0;
951  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
952  float* op = &out[rangeIndex * block_size];
953  __m256 vop0 = _mm256_setzero_ps();
954  __m256 vop8 = _mm256_setzero_ps();
955  __m256 vop16 = _mm256_setzero_ps();
956  __m256 vop24 = _mm256_setzero_ps();
957  __m256 vop32 = _mm256_setzero_ps();
958  __m256 vop40 = _mm256_setzero_ps();
959  __m256 vop48 = _mm256_setzero_ps();
960  __m256 vop56 = _mm256_setzero_ps();
961  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
962  ++dataInd) {
963  const int32_t idx = indices[dataInd];
964  CAFFE_ENFORCE(
965  idx >= 0 && idx < data_size,
966  "Index ",
967  dataInd,
968  " is out of bounds: ",
969  idx,
970  ", range 0 to ",
971  data_size);
972  float wgt = 1.f;
973  if (weights) {
974  wgt = weights[dataInd];
975  }
976  __m256 vwgt = _mm256_set1_ps(wgt);
977  const float16* ip = &input[idx * block_size];
978  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
979  ? (dataInd + prefdist_T0)
980  : dataInd;
981  const int32_t idx_pref_T0 = indices[next_T0];
982  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
983  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
984  vop0 = _mm256_fmadd_ps(
985  vwgt,
986  _mm256_cvtph_ps(
987  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
988  vop0);
989  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
990  vop8 = _mm256_fmadd_ps(
991  vwgt,
992  _mm256_cvtph_ps(
993  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
994  vop8);
995  // skip unnecessary prefetch of (&ip_next_T0[8])
996  vop16 = _mm256_fmadd_ps(
997  vwgt,
998  _mm256_cvtph_ps(
999  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1000  vop16);
1001  // skip unnecessary prefetch of (&ip_next_T0[16])
1002  vop24 = _mm256_fmadd_ps(
1003  vwgt,
1004  _mm256_cvtph_ps(
1005  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1006  vop24);
1007  // skip unnecessary prefetch of (&ip_next_T0[24])
1008  vop32 = _mm256_fmadd_ps(
1009  vwgt,
1010  _mm256_cvtph_ps(
1011  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1012  vop32);
1013  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1014  vop40 = _mm256_fmadd_ps(
1015  vwgt,
1016  _mm256_cvtph_ps(
1017  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1018  vop40);
1019  // skip unnecessary prefetch of (&ip_next_T0[40])
1020  vop48 = _mm256_fmadd_ps(
1021  vwgt,
1022  _mm256_cvtph_ps(
1023  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1024  vop48);
1025  // skip unnecessary prefetch of (&ip_next_T0[48])
1026  vop56 = _mm256_fmadd_ps(
1027  vwgt,
1028  _mm256_cvtph_ps(
1029  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1030  vop56);
1031  // skip unnecessary prefetch of (&ip_next_T0[56])
1032  }
1033  if (normalize_by_lengths == false) {
1034  _mm256_storeu_ps(&op[0], vop0);
1035  _mm256_storeu_ps(&op[8], vop8);
1036  _mm256_storeu_ps(&op[16], vop16);
1037  _mm256_storeu_ps(&op[24], vop24);
1038  _mm256_storeu_ps(&op[32], vop32);
1039  _mm256_storeu_ps(&op[40], vop40);
1040  _mm256_storeu_ps(&op[48], vop48);
1041  _mm256_storeu_ps(&op[56], vop56);
1042  } else if (lengths[rangeIndex]) {
1043  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1044  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1045  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1046  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1047  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1048  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1049  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1050  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1051  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1052  }
1053  }
1054  } else if (block_size == 32) {
1055  // unrolling 4 times
1056  int32_t dataInd = 0;
1057  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1058  float* op = &out[rangeIndex * block_size];
1059  __m256 vop0 = _mm256_setzero_ps();
1060  __m256 vop8 = _mm256_setzero_ps();
1061  __m256 vop16 = _mm256_setzero_ps();
1062  __m256 vop24 = _mm256_setzero_ps();
1063  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1064  ++dataInd) {
1065  const int32_t idx = indices[dataInd];
1066  CAFFE_ENFORCE(
1067  idx >= 0 && idx < data_size,
1068  "Index ",
1069  dataInd,
1070  " is out of bounds: ",
1071  idx,
1072  ", range 0 to ",
1073  data_size);
1074  float wgt = 1.f;
1075  if (weights) {
1076  wgt = weights[dataInd];
1077  }
1078  __m256 vwgt = _mm256_set1_ps(wgt);
1079  const float16* ip = &input[idx * block_size];
1080  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1081  ? (dataInd + prefdist_T0)
1082  : dataInd;
1083  const int32_t idx_pref_T0 = indices[next_T0];
1084  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1085  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1086  vop0 = _mm256_fmadd_ps(
1087  vwgt,
1088  _mm256_cvtph_ps(
1089  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1090  vop0);
1091  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1092  vop8 = _mm256_fmadd_ps(
1093  vwgt,
1094  _mm256_cvtph_ps(
1095  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1096  vop8);
1097  // skip unnecessary prefetch of (&ip_next_T0[8])
1098  vop16 = _mm256_fmadd_ps(
1099  vwgt,
1100  _mm256_cvtph_ps(
1101  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1102  vop16);
1103  // skip unnecessary prefetch of (&ip_next_T0[16])
1104  vop24 = _mm256_fmadd_ps(
1105  vwgt,
1106  _mm256_cvtph_ps(
1107  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1108  vop24);
1109  // skip unnecessary prefetch of (&ip_next_T0[24])
1110  }
1111  if (normalize_by_lengths == false) {
1112  _mm256_storeu_ps(&op[0], vop0);
1113  _mm256_storeu_ps(&op[8], vop8);
1114  _mm256_storeu_ps(&op[16], vop16);
1115  _mm256_storeu_ps(&op[24], vop24);
1116  } else if (lengths[rangeIndex]) {
1117  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1118  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1119  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1120  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1121  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1122  }
1123  }
1124  } else if (block_size == 16) {
1125  // unrolling 2 times
1126  int32_t dataInd = 0;
1127  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1128  float* op = &out[rangeIndex * block_size];
1129  __m256 vop0 = _mm256_setzero_ps();
1130  __m256 vop8 = _mm256_setzero_ps();
1131  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1132  ++dataInd) {
1133  const int32_t idx = indices[dataInd];
1134  CAFFE_ENFORCE(
1135  idx >= 0 && idx < data_size,
1136  "Index ",
1137  dataInd,
1138  " is out of bounds: ",
1139  idx,
1140  ", range 0 to ",
1141  data_size);
1142  float wgt = 1.f;
1143  if (weights) {
1144  wgt = weights[dataInd];
1145  }
1146  __m256 vwgt = _mm256_set1_ps(wgt);
1147  const float16* ip = &input[idx * block_size];
1148  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1149  ? (dataInd + prefdist_T0)
1150  : dataInd;
1151  const int32_t idx_pref_T0 = indices[next_T0];
1152  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1153  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1154  vop0 = _mm256_fmadd_ps(
1155  vwgt,
1156  _mm256_cvtph_ps(
1157  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1158  vop0);
1159  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1160  vop8 = _mm256_fmadd_ps(
1161  vwgt,
1162  _mm256_cvtph_ps(
1163  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1164  vop8);
1165  // skip unnecessary prefetch of (&ip_next_T0[8])
1166  }
1167  if (normalize_by_lengths == false) {
1168  _mm256_storeu_ps(&op[0], vop0);
1169  _mm256_storeu_ps(&op[8], vop8);
1170  } else if (lengths[rangeIndex]) {
1171  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1172  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1173  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1174  }
1175  }
1176  } else {
1177  // generic code
1178  int32_t dataInd = 0;
1179  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1180  float* op = &out[rangeIndex * block_size];
1181  TIndex j = 0;
1182  for (; j + 8 <= block_size; j += 8) {
1183  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1184  }
1185  for (; j < block_size; j++) {
1186  op[j] = 0.0f;
1187  }
1188  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1189  ++dataInd) {
1190  const int32_t idx = indices[dataInd];
1191  CAFFE_ENFORCE(
1192  idx >= 0 && idx < data_size,
1193  "Index ",
1194  dataInd,
1195  " is out of bounds: ",
1196  idx,
1197  ", range 0 to ",
1198  data_size);
1199  float wgt = 1.f;
1200  if (weights) {
1201  wgt = weights[dataInd];
1202  }
1203  __m256 vwgt = _mm256_set1_ps(wgt);
1204  const float16* ip = &input[idx * block_size];
1205  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1206  ? (dataInd + prefdist_T0)
1207  : dataInd;
1208  const int32_t idx_pref_T0 = indices[next_T0];
1209  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1210  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1211  j = 0;
1212  for (; j + 8 <= block_size; j += 8) {
1213  _mm256_storeu_ps(
1214  &op[j],
1215  _mm256_fmadd_ps(
1216  vwgt,
1217  _mm256_cvtph_ps(_mm_loadu_si128(
1218  reinterpret_cast<const __m128i*>(&ip[j]))),
1219  _mm256_loadu_ps(&op[j])));
1220  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1221  }
1222  float16 vtmp1[8] CAFFE2_ALIGNED(64);
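// Scalar tail for the fp16 path: stage one half-precision value in vtmp1[0],
// widen it with _mm256_cvtph_ps, and use only lane 0 of the result.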
1223  for (; j < block_size; j++) {
1224  vtmp1[0] = ip[j];
1225  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1226  op[j] += wgt * ((float*)(&vtmp2))[0];
1227  }
1228  }
1229  if (normalize_by_lengths && lengths[rangeIndex]) {
1230  float len_inv = 1.0f / lengths[rangeIndex];
1231  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1232  j = 0;
1233  for (; j + 8 <= block_size; j += 8) {
1234  _mm256_storeu_ps(
1235  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1236  }
1237  for (; j < block_size; j++) {
1238  op[j] = len_inv * op[j];
1239  }
1240  }
1241  }
1242  }
1243 }
1244 
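// Variant with 64-bit indices and fp16 embedding rows.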
1245 void EmbeddingLookup_int64_t_float16_float__avx2_fma(
1246  const TIndex block_size,
1247  const TIndex output_size,
1248  const TIndex index_size,
1249  const TIndex data_size,
1250  const float16* input,
1251  const int64_t* indices,
1252  const int* lengths,
1253  const float* weights,
1254  const float* scale_bias,
1255  bool normalize_by_lengths,
1256  float* out) {
1257  const int64_t prefdist_T0 = 16;
1258  CAFFE_ENFORCE(scale_bias == nullptr, "scale_bias must be nullptr");
1259  if (block_size == 128) {
1260  // unrolling 16 times
1261  int64_t dataInd = 0;
1262  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1263  float* op = &out[rangeIndex * block_size];
1264  __m256 vop0 = _mm256_setzero_ps();
1265  __m256 vop8 = _mm256_setzero_ps();
1266  __m256 vop16 = _mm256_setzero_ps();
1267  __m256 vop24 = _mm256_setzero_ps();
1268  __m256 vop32 = _mm256_setzero_ps();
1269  __m256 vop40 = _mm256_setzero_ps();
1270  __m256 vop48 = _mm256_setzero_ps();
1271  __m256 vop56 = _mm256_setzero_ps();
1272  __m256 vop64 = _mm256_setzero_ps();
1273  __m256 vop72 = _mm256_setzero_ps();
1274  __m256 vop80 = _mm256_setzero_ps();
1275  __m256 vop88 = _mm256_setzero_ps();
1276  __m256 vop96 = _mm256_setzero_ps();
1277  __m256 vop104 = _mm256_setzero_ps();
1278  __m256 vop112 = _mm256_setzero_ps();
1279  __m256 vop120 = _mm256_setzero_ps();
1280  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1281  ++dataInd) {
1282  const int64_t idx = indices[dataInd];
1283  CAFFE_ENFORCE(
1284  idx >= 0 && idx < data_size,
1285  "Index ",
1286  dataInd,
1287  " is out of bounds: ",
1288  idx,
1289  ", range 0 to ",
1290  data_size);
1291  float wgt = 1.f;
1292  if (weights) {
1293  wgt = weights[dataInd];
1294  }
1295  __m256 vwgt = _mm256_set1_ps(wgt);
1296  const float16* ip = &input[idx * block_size];
1297  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1298  ? (dataInd + prefdist_T0)
1299  : dataInd;
1300  const int64_t idx_pref_T0 = indices[next_T0];
1301  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1302  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1303  vop0 = _mm256_fmadd_ps(
1304  vwgt,
1305  _mm256_cvtph_ps(
1306  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1307  vop0);
1308  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1309  vop8 = _mm256_fmadd_ps(
1310  vwgt,
1311  _mm256_cvtph_ps(
1312  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1313  vop8);
1314  // skip unnecessary prefetch of (&ip_next_T0[8])
1315  vop16 = _mm256_fmadd_ps(
1316  vwgt,
1317  _mm256_cvtph_ps(
1318  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1319  vop16);
1320  // skip unnecessary prefetch of (&ip_next_T0[16])
1321  vop24 = _mm256_fmadd_ps(
1322  vwgt,
1323  _mm256_cvtph_ps(
1324  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1325  vop24);
1326  // skip unnecessary prefetch of (&ip_next_T0[24])
1327  vop32 = _mm256_fmadd_ps(
1328  vwgt,
1329  _mm256_cvtph_ps(
1330  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1331  vop32);
1332  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1333  vop40 = _mm256_fmadd_ps(
1334  vwgt,
1335  _mm256_cvtph_ps(
1336  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1337  vop40);
1338  // skip unnecessary prefetch of (&ip_next_T0[40])
1339  vop48 = _mm256_fmadd_ps(
1340  vwgt,
1341  _mm256_cvtph_ps(
1342  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1343  vop48);
1344  // skip unnecessary prefetch of (&ip_next_T0[48])
1345  vop56 = _mm256_fmadd_ps(
1346  vwgt,
1347  _mm256_cvtph_ps(
1348  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1349  vop56);
1350  // skip unnecessary prefetch of (&ip_next_T0[56])
1351  vop64 = _mm256_fmadd_ps(
1352  vwgt,
1353  _mm256_cvtph_ps(
1354  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1355  vop64);
1356  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
1357  vop72 = _mm256_fmadd_ps(
1358  vwgt,
1359  _mm256_cvtph_ps(
1360  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1361  vop72);
1362  // skip unnecessary prefetch of (&ip_next_T0[72])
1363  vop80 = _mm256_fmadd_ps(
1364  vwgt,
1365  _mm256_cvtph_ps(
1366  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1367  vop80);
1368  // skip unnecessary prefetch of (&ip_next_T0[80])
1369  vop88 = _mm256_fmadd_ps(
1370  vwgt,
1371  _mm256_cvtph_ps(
1372  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1373  vop88);
1374  // skip unnecessary prefetch of (&ip_next_T0[88])
1375  vop96 = _mm256_fmadd_ps(
1376  vwgt,
1377  _mm256_cvtph_ps(
1378  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1379  vop96);
1380  _mm_prefetch((&ip_next_T0[96]), _MM_HINT_T0);
1381  vop104 = _mm256_fmadd_ps(
1382  vwgt,
1383  _mm256_cvtph_ps(
1384  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1385  vop104);
1386  // skip unnecessary prefetch of (&ip_next_T0[104])
1387  vop112 = _mm256_fmadd_ps(
1388  vwgt,
1389  _mm256_cvtph_ps(
1390  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1391  vop112);
1392  // skip unnecessary prefetch of (&ip_next_T0[112])
1393  vop120 = _mm256_fmadd_ps(
1394  vwgt,
1395  _mm256_cvtph_ps(
1396  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1397  vop120);
1398  // skip unnecessary prefetch of (&ip_next_T0[120])
1399  }
1400  if (normalize_by_lengths == false) {
1401  _mm256_storeu_ps(&op[0], vop0);
1402  _mm256_storeu_ps(&op[8], vop8);
1403  _mm256_storeu_ps(&op[16], vop16);
1404  _mm256_storeu_ps(&op[24], vop24);
1405  _mm256_storeu_ps(&op[32], vop32);
1406  _mm256_storeu_ps(&op[40], vop40);
1407  _mm256_storeu_ps(&op[48], vop48);
1408  _mm256_storeu_ps(&op[56], vop56);
1409  _mm256_storeu_ps(&op[64], vop64);
1410  _mm256_storeu_ps(&op[72], vop72);
1411  _mm256_storeu_ps(&op[80], vop80);
1412  _mm256_storeu_ps(&op[88], vop88);
1413  _mm256_storeu_ps(&op[96], vop96);
1414  _mm256_storeu_ps(&op[104], vop104);
1415  _mm256_storeu_ps(&op[112], vop112);
1416  _mm256_storeu_ps(&op[120], vop120);
1417  } else if (lengths[rangeIndex]) {
1418  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1419  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1420  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1421  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1422  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1423  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1424  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1425  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1426  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1427  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1428  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1429  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1430  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1431  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1432  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1433  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1434  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1435  }
1436  }
1437  } else if (block_size == 64) {
1438  // unrolling 8 times
1439  int64_t dataInd = 0;
1440  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1441  float* op = &out[rangeIndex * block_size];
1442  __m256 vop0 = _mm256_setzero_ps();
1443  __m256 vop8 = _mm256_setzero_ps();
1444  __m256 vop16 = _mm256_setzero_ps();
1445  __m256 vop24 = _mm256_setzero_ps();
1446  __m256 vop32 = _mm256_setzero_ps();
1447  __m256 vop40 = _mm256_setzero_ps();
1448  __m256 vop48 = _mm256_setzero_ps();
1449  __m256 vop56 = _mm256_setzero_ps();
1450  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1451  ++dataInd) {
1452  const int64_t idx = indices[dataInd];
1453  CAFFE_ENFORCE(
1454  idx >= 0 && idx < data_size,
1455  "Index ",
1456  dataInd,
1457  " is out of bounds: ",
1458  idx,
1459  ", range 0 to ",
1460  data_size);
1461  float wgt = 1.f;
1462  if (weights) {
1463  wgt = weights[dataInd];
1464  }
1465  __m256 vwgt = _mm256_set1_ps(wgt);
1466  const float16* ip = &input[idx * block_size];
1467  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1468  ? (dataInd + prefdist_T0)
1469  : dataInd;
1470  const int64_t idx_pref_T0 = indices[next_T0];
1471  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1472  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1473  vop0 = _mm256_fmadd_ps(
1474  vwgt,
1475  _mm256_cvtph_ps(
1476  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1477  vop0);
1478  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1479  vop8 = _mm256_fmadd_ps(
1480  vwgt,
1481  _mm256_cvtph_ps(
1482  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1483  vop8);
1484  // skip unnecessary prefetch of (&ip_next_T0[8])
1485  vop16 = _mm256_fmadd_ps(
1486  vwgt,
1487  _mm256_cvtph_ps(
1488  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1489  vop16);
1490  // skip unnecessary prefetch of (&ip_next_T0[16])
1491  vop24 = _mm256_fmadd_ps(
1492  vwgt,
1493  _mm256_cvtph_ps(
1494  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1495  vop24);
1496  // skip unnecessary prefetch of (&ip_next_T0[24])
1497  vop32 = _mm256_fmadd_ps(
1498  vwgt,
1499  _mm256_cvtph_ps(
1500  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1501  vop32);
1502  _mm_prefetch((&ip_next_T0[32]), _MM_HINT_T0);
1503  vop40 = _mm256_fmadd_ps(
1504  vwgt,
1505  _mm256_cvtph_ps(
1506  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1507  vop40);
 1508  // skip unnecessary prefetch of (&ip_next_T0[40])
1509  vop48 = _mm256_fmadd_ps(
1510  vwgt,
1511  _mm256_cvtph_ps(
1512  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1513  vop48);
 1514  // skip unnecessary prefetch of (&ip_next_T0[48])
1515  vop56 = _mm256_fmadd_ps(
1516  vwgt,
1517  _mm256_cvtph_ps(
1518  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1519  vop56);
 1520  // skip unnecessary prefetch of (&ip_next_T0[56])
1521  }
1522  if (normalize_by_lengths == false) {
1523  _mm256_storeu_ps(&op[0], vop0);
1524  _mm256_storeu_ps(&op[8], vop8);
1525  _mm256_storeu_ps(&op[16], vop16);
1526  _mm256_storeu_ps(&op[24], vop24);
1527  _mm256_storeu_ps(&op[32], vop32);
1528  _mm256_storeu_ps(&op[40], vop40);
1529  _mm256_storeu_ps(&op[48], vop48);
1530  _mm256_storeu_ps(&op[56], vop56);
1531  } else if (lengths[rangeIndex]) {
1532  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1533  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1534  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1535  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1536  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1537  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1538  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1539  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1540  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1541  }
1542  }
1543  } else if (block_size == 32) {
1544  // unrolling 4 times
1545  int64_t dataInd = 0;
1546  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1547  float* op = &out[rangeIndex * block_size];
1548  __m256 vop0 = _mm256_setzero_ps();
1549  __m256 vop8 = _mm256_setzero_ps();
1550  __m256 vop16 = _mm256_setzero_ps();
1551  __m256 vop24 = _mm256_setzero_ps();
1552  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1553  ++dataInd) {
1554  const int64_t idx = indices[dataInd];
1555  CAFFE_ENFORCE(
1556  idx >= 0 && idx < data_size,
1557  "Index ",
1558  dataInd,
1559  " is out of bounds: ",
1560  idx,
1561  ", range 0 to ",
1562  data_size);
1563  float wgt = 1.f;
1564  if (weights) {
1565  wgt = weights[dataInd];
1566  }
1567  __m256 vwgt = _mm256_set1_ps(wgt);
1568  const float16* ip = &input[idx * block_size];
1569  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1570  ? (dataInd + prefdist_T0)
1571  : dataInd;
1572  const int64_t idx_pref_T0 = indices[next_T0];
1573  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1574  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1575  vop0 = _mm256_fmadd_ps(
1576  vwgt,
1577  _mm256_cvtph_ps(
1578  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1579  vop0);
1580  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1581  vop8 = _mm256_fmadd_ps(
1582  vwgt,
1583  _mm256_cvtph_ps(
1584  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1585  vop8);
 1586  // skip unnecessary prefetch of (&ip_next_T0[8])
1587  vop16 = _mm256_fmadd_ps(
1588  vwgt,
1589  _mm256_cvtph_ps(
1590  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1591  vop16);
 1592  // skip unnecessary prefetch of (&ip_next_T0[16])
1593  vop24 = _mm256_fmadd_ps(
1594  vwgt,
1595  _mm256_cvtph_ps(
1596  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1597  vop24);
 1598  // skip unnecessary prefetch of (&ip_next_T0[24])
1599  }
1600  if (normalize_by_lengths == false) {
1601  _mm256_storeu_ps(&op[0], vop0);
1602  _mm256_storeu_ps(&op[8], vop8);
1603  _mm256_storeu_ps(&op[16], vop16);
1604  _mm256_storeu_ps(&op[24], vop24);
1605  } else if (lengths[rangeIndex]) {
1606  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1607  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1608  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1609  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1610  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1611  }
1612  }
1613  } else if (block_size == 16) {
1614  // unrolling 2 times
1615  int64_t dataInd = 0;
1616  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1617  float* op = &out[rangeIndex * block_size];
1618  __m256 vop0 = _mm256_setzero_ps();
1619  __m256 vop8 = _mm256_setzero_ps();
1620  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1621  ++dataInd) {
1622  const int64_t idx = indices[dataInd];
1623  CAFFE_ENFORCE(
1624  idx >= 0 && idx < data_size,
1625  "Index ",
1626  dataInd,
1627  " is out of bounds: ",
1628  idx,
1629  ", range 0 to ",
1630  data_size);
1631  float wgt = 1.f;
1632  if (weights) {
1633  wgt = weights[dataInd];
1634  }
1635  __m256 vwgt = _mm256_set1_ps(wgt);
1636  const float16* ip = &input[idx * block_size];
1637  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1638  ? (dataInd + prefdist_T0)
1639  : dataInd;
1640  const int64_t idx_pref_T0 = indices[next_T0];
1641  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1642  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1643  vop0 = _mm256_fmadd_ps(
1644  vwgt,
1645  _mm256_cvtph_ps(
1646  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1647  vop0);
1648  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1649  vop8 = _mm256_fmadd_ps(
1650  vwgt,
1651  _mm256_cvtph_ps(
1652  _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1653  vop8);
 1654  // skip unnecessary prefetch of (&ip_next_T0[8])
1655  }
1656  if (normalize_by_lengths == false) {
1657  _mm256_storeu_ps(&op[0], vop0);
1658  _mm256_storeu_ps(&op[8], vop8);
1659  } else if (lengths[rangeIndex]) {
1660  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1661  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1662  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1663  }
1664  }
1665  } else {
1666  // generic code
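  // Fallback for arbitrary block_size: zero the output row, then accumulate
  // 8 fp16 values at a time (widened to fp32 via _mm256_cvtph_ps) and handle
  // any remainder with the scalar loop further below.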
1667  int64_t dataInd = 0;
1668  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1669  float* op = &out[rangeIndex * block_size];
1670  TIndex j = 0;
1671  for (; j + 8 <= block_size; j += 8) {
1672  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1673  }
1674  for (; j < block_size; j++) {
1675  op[j] = 0.0f;
1676  }
1677  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
1678  ++dataInd) {
1679  const int64_t idx = indices[dataInd];
1680  CAFFE_ENFORCE(
1681  idx >= 0 && idx < data_size,
1682  "Index ",
1683  dataInd,
1684  " is out of bounds: ",
1685  idx,
1686  ", range 0 to ",
1687  data_size);
1688  float wgt = 1.f;
1689  if (weights) {
1690  wgt = weights[dataInd];
1691  }
1692  __m256 vwgt = _mm256_set1_ps(wgt);
1693  const float16* ip = &input[idx * block_size];
1694  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1695  ? (dataInd + prefdist_T0)
1696  : dataInd;
1697  const int64_t idx_pref_T0 = indices[next_T0];
1698  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1699  const float16* ip_next_T0 = &input[idx_pref_T0 * block_size];
1700  j = 0;
1701  for (; j + 8 <= block_size; j += 8) {
1702  _mm256_storeu_ps(
1703  &op[j],
1704  _mm256_fmadd_ps(
1705  vwgt,
1706  _mm256_cvtph_ps(_mm_loadu_si128(
1707  reinterpret_cast<const __m128i*>(&ip[j]))),
1708  _mm256_loadu_ps(&op[j])));
1709  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
1710  }
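  // Scalar remainder: stage a single fp16 value in an aligned temporary,
  // widen with _mm256_cvtph_ps, and use only element 0 of the result.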
1711  float16 vtmp1[8] CAFFE2_ALIGNED(64);
1712  for (; j < block_size; j++) {
1713  vtmp1[0] = ip[j];
1714  __m256 vtmp2 = _mm256_cvtph_ps(*((__m128i*)vtmp1));
1715  op[j] += wgt * ((float*)(&vtmp2))[0];
1716  }
1717  }
1718  if (normalize_by_lengths && lengths[rangeIndex]) {
1719  float len_inv = 1.0f / lengths[rangeIndex];
1720  __m256 vlen_inv = _mm256_set1_ps(len_inv);
1721  j = 0;
1722  for (; j + 8 <= block_size; j += 8) {
1723  _mm256_storeu_ps(
1724  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1725  }
1726  for (; j < block_size; j++) {
1727  op[j] = len_inv * op[j];
1728  }
1729  }
1730  }
1731  }
1732 }
1733 
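// EmbeddingLookup_int32_t_uint8_t_float__avx2_fma: pooled lookup over a
// uint8-quantized embedding table. Row `idx` of `input` is dequantized on the
// fly with its per-row pair (scale, bias) = (scale_bias[2 * idx],
// scale_bias[2 * idx + 1]) and accumulated into fp32 output rows.
//
// A minimal scalar sketch of the semantics implemented below (reference only;
// the vectorized code also leaves rows with empty segments unwritten when
// normalizing):
//
//   TIndex dataInd = 0;
//   for (TIndex r = 0; r < output_size; ++r) {
//     float* op = out + r * block_size;
//     std::fill(op, op + block_size, 0.f);
//     for (int k = 0; k < lengths[r]; ++k, ++dataInd) {
//       const int32_t idx = indices[dataInd];
//       const float w = weights ? weights[dataInd] : 1.f;
//       const float scale = scale_bias[2 * idx];
//       const float bias = scale_bias[2 * idx + 1];
//       for (TIndex j = 0; j < block_size; ++j) {
//         op[j] += w * (scale * input[idx * block_size + j] + bias);
//       }
//     }
//     if (normalize_by_lengths && lengths[r]) {
//       for (TIndex j = 0; j < block_size; ++j) {
//         op[j] /= lengths[r];
//       }
//     }
//   }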
1734 void EmbeddingLookup_int32_t_uint8_t_float__avx2_fma(
1735  const TIndex block_size,
1736  const TIndex output_size,
1737  const TIndex index_size,
1738  const TIndex data_size,
1739  const uint8_t* input,
1740  const int32_t* indices,
1741  const int* lengths,
1742  const float* weights,
1743  const float* scale_bias,
1744  bool normalize_by_lengths,
1745  float* out) {
1746  const int32_t prefdist_T0 = 16;
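  // prefdist_T0 is the prefetch distance: while the current row is being
  // accumulated, the row referenced 16 indices ahead is prefetched with
  // _MM_HINT_T0; the lookahead index is clamped at the end of the index array.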
1747  CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");
1748  if (block_size == 128) {
1749  // unrolling 16 times
1750  int32_t dataInd = 0;
1751  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1752  float* op = &out[rangeIndex * block_size];
1753  __m256 vop0 = _mm256_setzero_ps();
1754  __m256 vop8 = _mm256_setzero_ps();
1755  __m256 vop16 = _mm256_setzero_ps();
1756  __m256 vop24 = _mm256_setzero_ps();
1757  __m256 vop32 = _mm256_setzero_ps();
1758  __m256 vop40 = _mm256_setzero_ps();
1759  __m256 vop48 = _mm256_setzero_ps();
1760  __m256 vop56 = _mm256_setzero_ps();
1761  __m256 vop64 = _mm256_setzero_ps();
1762  __m256 vop72 = _mm256_setzero_ps();
1763  __m256 vop80 = _mm256_setzero_ps();
1764  __m256 vop88 = _mm256_setzero_ps();
1765  __m256 vop96 = _mm256_setzero_ps();
1766  __m256 vop104 = _mm256_setzero_ps();
1767  __m256 vop112 = _mm256_setzero_ps();
1768  __m256 vop120 = _mm256_setzero_ps();
1769  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1770  ++dataInd) {
1771  const int32_t idx = indices[dataInd];
1772  CAFFE_ENFORCE(
1773  idx >= 0 && idx < data_size,
1774  "Index ",
1775  dataInd,
1776  " is out of bounds: ",
1777  idx,
1778  ", range 0 to ",
1779  data_size);
1780  float wgt = 1.f;
1781  float bio;
1782  if (weights) {
1783  wgt = weights[dataInd];
1784  }
1785  bio = wgt * scale_bias[2 * idx + 1];
1786  wgt = wgt * scale_bias[2 * idx];
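          // Fold the per-index weight into the dequantization constants:
          // w * (scale * x + bias) == (w * scale) * x + (w * bias), so vwgt
          // below carries w * scale and vbio carries w * bias.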
1787  __m256 vbio = _mm256_set1_ps(bio);
1788  __m256 vwgt = _mm256_set1_ps(wgt);
1789  const uint8_t* ip = &input[idx * block_size];
1790  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1791  ? (dataInd + prefdist_T0)
1792  : dataInd;
1793  const int32_t idx_pref_T0 = indices[next_T0];
1794  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1795  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
1796  vop0 = _mm256_fmadd_ps(
1797  vwgt,
1798  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1799  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
1800  _mm256_add_ps(vop0, vbio));
1801  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1802  vop8 = _mm256_fmadd_ps(
1803  vwgt,
1804  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1805  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
1806  _mm256_add_ps(vop8, vbio));
 1807  // skip unnecessary prefetch of (&ip_next_T0[8])
1808  vop16 = _mm256_fmadd_ps(
1809  vwgt,
1810  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1811  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
1812  _mm256_add_ps(vop16, vbio));
 1813  // skip unnecessary prefetch of (&ip_next_T0[16])
1814  vop24 = _mm256_fmadd_ps(
1815  vwgt,
1816  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1817  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
1818  _mm256_add_ps(vop24, vbio));
 1819  // skip unnecessary prefetch of (&ip_next_T0[24])
1820  vop32 = _mm256_fmadd_ps(
1821  vwgt,
1822  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1823  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
1824  _mm256_add_ps(vop32, vbio));
 1825  // skip unnecessary prefetch of (&ip_next_T0[32])
1826  vop40 = _mm256_fmadd_ps(
1827  vwgt,
1828  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1829  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
1830  _mm256_add_ps(vop40, vbio));
 1831  // skip unnecessary prefetch of (&ip_next_T0[40])
1832  vop48 = _mm256_fmadd_ps(
1833  vwgt,
1834  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1835  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
1836  _mm256_add_ps(vop48, vbio));
 1837  // skip unnecessary prefetch of (&ip_next_T0[48])
1838  vop56 = _mm256_fmadd_ps(
1839  vwgt,
1840  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1841  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
1842  _mm256_add_ps(vop56, vbio));
 1843  // skip unnecessary prefetch of (&ip_next_T0[56])
1844  vop64 = _mm256_fmadd_ps(
1845  vwgt,
1846  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1847  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
1848  _mm256_add_ps(vop64, vbio));
1849  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
1850  vop72 = _mm256_fmadd_ps(
1851  vwgt,
1852  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1853  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
1854  _mm256_add_ps(vop72, vbio));
 1855  // skip unnecessary prefetch of (&ip_next_T0[72])
1856  vop80 = _mm256_fmadd_ps(
1857  vwgt,
1858  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1859  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
1860  _mm256_add_ps(vop80, vbio));
 1861  // skip unnecessary prefetch of (&ip_next_T0[80])
1862  vop88 = _mm256_fmadd_ps(
1863  vwgt,
1864  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1865  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
1866  _mm256_add_ps(vop88, vbio));
 1867  // skip unnecessary prefetch of (&ip_next_T0[88])
1868  vop96 = _mm256_fmadd_ps(
1869  vwgt,
1870  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1871  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
1872  _mm256_add_ps(vop96, vbio));
 1873  // skip unnecessary prefetch of (&ip_next_T0[96])
1874  vop104 = _mm256_fmadd_ps(
1875  vwgt,
1876  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1877  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
1878  _mm256_add_ps(vop104, vbio));
 1879  // skip unnecessary prefetch of (&ip_next_T0[104])
1880  vop112 = _mm256_fmadd_ps(
1881  vwgt,
1882  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1883  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
1884  _mm256_add_ps(vop112, vbio));
 1885  // skip unnecessary prefetch of (&ip_next_T0[112])
1886  vop120 = _mm256_fmadd_ps(
1887  vwgt,
1888  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1889  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
1890  _mm256_add_ps(vop120, vbio));
 1891  // skip unnecessary prefetch of (&ip_next_T0[120])
1892  }
1893  if (normalize_by_lengths == false) {
1894  _mm256_storeu_ps(&op[0], vop0);
1895  _mm256_storeu_ps(&op[8], vop8);
1896  _mm256_storeu_ps(&op[16], vop16);
1897  _mm256_storeu_ps(&op[24], vop24);
1898  _mm256_storeu_ps(&op[32], vop32);
1899  _mm256_storeu_ps(&op[40], vop40);
1900  _mm256_storeu_ps(&op[48], vop48);
1901  _mm256_storeu_ps(&op[56], vop56);
1902  _mm256_storeu_ps(&op[64], vop64);
1903  _mm256_storeu_ps(&op[72], vop72);
1904  _mm256_storeu_ps(&op[80], vop80);
1905  _mm256_storeu_ps(&op[88], vop88);
1906  _mm256_storeu_ps(&op[96], vop96);
1907  _mm256_storeu_ps(&op[104], vop104);
1908  _mm256_storeu_ps(&op[112], vop112);
1909  _mm256_storeu_ps(&op[120], vop120);
1910  } else if (lengths[rangeIndex]) {
1911  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
1912  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1913  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1914  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1915  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1916  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1917  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1918  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1919  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1920  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1921  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1922  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1923  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1924  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1925  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1926  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1927  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1928  }
1929  }
1930  } else if (block_size == 64) {
1931  // unrolling 8 times
1932  int32_t dataInd = 0;
1933  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1934  float* op = &out[rangeIndex * block_size];
1935  __m256 vop0 = _mm256_setzero_ps();
1936  __m256 vop8 = _mm256_setzero_ps();
1937  __m256 vop16 = _mm256_setzero_ps();
1938  __m256 vop24 = _mm256_setzero_ps();
1939  __m256 vop32 = _mm256_setzero_ps();
1940  __m256 vop40 = _mm256_setzero_ps();
1941  __m256 vop48 = _mm256_setzero_ps();
1942  __m256 vop56 = _mm256_setzero_ps();
1943  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
1944  ++dataInd) {
1945  const int32_t idx = indices[dataInd];
1946  CAFFE_ENFORCE(
1947  idx >= 0 && idx < data_size,
1948  "Index ",
1949  dataInd,
1950  " is out of bounds: ",
1951  idx,
1952  ", range 0 to ",
1953  data_size);
1954  float wgt = 1.f;
1955  float bio;
1956  if (weights) {
1957  wgt = weights[dataInd];
1958  }
1959  bio = wgt * scale_bias[2 * idx + 1];
1960  wgt = wgt * scale_bias[2 * idx];
1961  __m256 vbio = _mm256_set1_ps(bio);
1962  __m256 vwgt = _mm256_set1_ps(wgt);
1963  const uint8_t* ip = &input[idx * block_size];
1964  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
1965  ? (dataInd + prefdist_T0)
1966  : dataInd;
1967  const int32_t idx_pref_T0 = indices[next_T0];
1968  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
1969  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
1970  vop0 = _mm256_fmadd_ps(
1971  vwgt,
1972  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1973  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
1974  _mm256_add_ps(vop0, vbio));
1975  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
1976  vop8 = _mm256_fmadd_ps(
1977  vwgt,
1978  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1979  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
1980  _mm256_add_ps(vop8, vbio));
 1981  // skip unnecessary prefetch of (&ip_next_T0[8])
1982  vop16 = _mm256_fmadd_ps(
1983  vwgt,
1984  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1985  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
1986  _mm256_add_ps(vop16, vbio));
 1987  // skip unnecessary prefetch of (&ip_next_T0[16])
1988  vop24 = _mm256_fmadd_ps(
1989  vwgt,
1990  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1991  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
1992  _mm256_add_ps(vop24, vbio));
 1993  // skip unnecessary prefetch of (&ip_next_T0[24])
1994  vop32 = _mm256_fmadd_ps(
1995  vwgt,
1996  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
1997  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
1998  _mm256_add_ps(vop32, vbio));
 1999  // skip unnecessary prefetch of (&ip_next_T0[32])
2000  vop40 = _mm256_fmadd_ps(
2001  vwgt,
2002  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2003  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2004  _mm256_add_ps(vop40, vbio));
 2005  // skip unnecessary prefetch of (&ip_next_T0[40])
2006  vop48 = _mm256_fmadd_ps(
2007  vwgt,
2008  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2009  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2010  _mm256_add_ps(vop48, vbio));
 2011  // skip unnecessary prefetch of (&ip_next_T0[48])
2012  vop56 = _mm256_fmadd_ps(
2013  vwgt,
2014  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2015  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2016  _mm256_add_ps(vop56, vbio));
 2017  // skip unnecessary prefetch of (&ip_next_T0[56])
2018  }
2019  if (normalize_by_lengths == false) {
2020  _mm256_storeu_ps(&op[0], vop0);
2021  _mm256_storeu_ps(&op[8], vop8);
2022  _mm256_storeu_ps(&op[16], vop16);
2023  _mm256_storeu_ps(&op[24], vop24);
2024  _mm256_storeu_ps(&op[32], vop32);
2025  _mm256_storeu_ps(&op[40], vop40);
2026  _mm256_storeu_ps(&op[48], vop48);
2027  _mm256_storeu_ps(&op[56], vop56);
2028  } else if (lengths[rangeIndex]) {
2029  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2030  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2031  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2032  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2033  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2034  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2035  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2036  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2037  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2038  }
2039  }
2040  } else if (block_size == 32) {
2041  // unrolling 4 times
2042  int32_t dataInd = 0;
2043  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2044  float* op = &out[rangeIndex * block_size];
2045  __m256 vop0 = _mm256_setzero_ps();
2046  __m256 vop8 = _mm256_setzero_ps();
2047  __m256 vop16 = _mm256_setzero_ps();
2048  __m256 vop24 = _mm256_setzero_ps();
2049  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2050  ++dataInd) {
2051  const int32_t idx = indices[dataInd];
2052  CAFFE_ENFORCE(
2053  idx >= 0 && idx < data_size,
2054  "Index ",
2055  dataInd,
2056  " is out of bounds: ",
2057  idx,
2058  ", range 0 to ",
2059  data_size);
2060  float wgt = 1.f;
2061  float bio;
2062  if (weights) {
2063  wgt = weights[dataInd];
2064  }
2065  bio = wgt * scale_bias[2 * idx + 1];
2066  wgt = wgt * scale_bias[2 * idx];
2067  __m256 vbio = _mm256_set1_ps(bio);
2068  __m256 vwgt = _mm256_set1_ps(wgt);
2069  const uint8_t* ip = &input[idx * block_size];
2070  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2071  ? (dataInd + prefdist_T0)
2072  : dataInd;
2073  const int32_t idx_pref_T0 = indices[next_T0];
2074  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2075  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2076  vop0 = _mm256_fmadd_ps(
2077  vwgt,
2078  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2079  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2080  _mm256_add_ps(vop0, vbio));
2081  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2082  vop8 = _mm256_fmadd_ps(
2083  vwgt,
2084  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2085  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2086  _mm256_add_ps(vop8, vbio));
 2087  // skip unnecessary prefetch of (&ip_next_T0[8])
2088  vop16 = _mm256_fmadd_ps(
2089  vwgt,
2090  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2091  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2092  _mm256_add_ps(vop16, vbio));
 2093  // skip unnecessary prefetch of (&ip_next_T0[16])
2094  vop24 = _mm256_fmadd_ps(
2095  vwgt,
2096  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2097  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2098  _mm256_add_ps(vop24, vbio));
 2099  // skip unnecessary prefetch of (&ip_next_T0[24])
2100  }
2101  if (normalize_by_lengths == false) {
2102  _mm256_storeu_ps(&op[0], vop0);
2103  _mm256_storeu_ps(&op[8], vop8);
2104  _mm256_storeu_ps(&op[16], vop16);
2105  _mm256_storeu_ps(&op[24], vop24);
2106  } else if (lengths[rangeIndex]) {
2107  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2108  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2109  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2110  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2111  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2112  }
2113  }
2114  } else if (block_size == 16) {
2115  // unrolling 2 times
2116  int32_t dataInd = 0;
2117  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2118  float* op = &out[rangeIndex * block_size];
2119  __m256 vop0 = _mm256_setzero_ps();
2120  __m256 vop8 = _mm256_setzero_ps();
2121  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2122  ++dataInd) {
2123  const int32_t idx = indices[dataInd];
2124  CAFFE_ENFORCE(
2125  idx >= 0 && idx < data_size,
2126  "Index ",
2127  dataInd,
2128  " is out of bounds: ",
2129  idx,
2130  ", range 0 to ",
2131  data_size);
2132  float wgt = 1.f;
2133  float bio;
2134  if (weights) {
2135  wgt = weights[dataInd];
2136  }
2137  bio = wgt * scale_bias[2 * idx + 1];
2138  wgt = wgt * scale_bias[2 * idx];
2139  __m256 vbio = _mm256_set1_ps(bio);
2140  __m256 vwgt = _mm256_set1_ps(wgt);
2141  const uint8_t* ip = &input[idx * block_size];
2142  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2143  ? (dataInd + prefdist_T0)
2144  : dataInd;
2145  const int32_t idx_pref_T0 = indices[next_T0];
2146  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2147  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2148  vop0 = _mm256_fmadd_ps(
2149  vwgt,
2150  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2151  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2152  _mm256_add_ps(vop0, vbio));
2153  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2154  vop8 = _mm256_fmadd_ps(
2155  vwgt,
2156  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2157  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2158  _mm256_add_ps(vop8, vbio));
 2159  // skip unnecessary prefetch of (&ip_next_T0[8])
2160  }
2161  if (normalize_by_lengths == false) {
2162  _mm256_storeu_ps(&op[0], vop0);
2163  _mm256_storeu_ps(&op[8], vop8);
2164  } else if (lengths[rangeIndex]) {
2165  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2166  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2167  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2168  }
2169  }
2170  } else {
2171  // generic code
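  // Fallback for arbitrary block_size: zero the output row, then widen 8
  // uint8 values at a time to fp32 (_mm256_cvtepu8_epi32 followed by
  // _mm256_cvtepi32_ps) and fmadd them into the row; a scalar loop handles
  // any remainder.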
2172  int32_t dataInd = 0;
2173  for (int32_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2174  float* op = &out[rangeIndex * block_size];
2175  TIndex j = 0;
2176  for (; j + 8 <= block_size; j += 8) {
2177  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2178  }
2179  for (; j < block_size; j++) {
2180  op[j] = 0.0f;
2181  }
2182  for (int32_t start = dataInd; dataInd < start + lengths[rangeIndex];
2183  ++dataInd) {
2184  const int32_t idx = indices[dataInd];
2185  CAFFE_ENFORCE(
2186  idx >= 0 && idx < data_size,
2187  "Index ",
2188  dataInd,
2189  " is out of bounds: ",
2190  idx,
2191  ", range 0 to ",
2192  data_size);
2193  float wgt = 1.f;
2194  float bio;
2195  if (weights) {
2196  wgt = weights[dataInd];
2197  }
2198  assert(scale_bias);
2199  bio = wgt * scale_bias[2 * idx + 1];
2200  wgt = wgt * scale_bias[2 * idx];
2201  __m256 vbio = _mm256_set1_ps(bio);
2202  __m256 vwgt = _mm256_set1_ps(wgt);
2203  const uint8_t* ip = &input[idx * block_size];
2204  const int32_t next_T0 = (dataInd < index_size - prefdist_T0)
2205  ? (dataInd + prefdist_T0)
2206  : dataInd;
2207  const int32_t idx_pref_T0 = indices[next_T0];
2208  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2209  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2210  j = 0;
2211  for (; j + 8 <= block_size; j += 8) {
2212  _mm256_storeu_ps(
2213  &op[j],
2214  _mm256_fmadd_ps(
2215  vwgt,
2216  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2217  reinterpret_cast<const __m128i*>(&ip[j])))),
2218  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2219  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2220  }
2221  for (; j < block_size; j++) {
2222  op[j] += wgt * ((float)ip[j]) + bio;
2223  }
2224  }
2225  if (normalize_by_lengths && lengths[rangeIndex]) {
2226  float len_inv = 1.0f / lengths[rangeIndex];
2227  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2228  j = 0;
2229  for (; j + 8 <= block_size; j += 8) {
2230  _mm256_storeu_ps(
2231  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2232  }
2233  for (; j < block_size; j++) {
2234  op[j] = len_inv * op[j];
2235  }
2236  }
2237  }
2238  }
2239 }
2240 
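// EmbeddingLookup_int64_t_uint8_t_float__avx2_fma: identical to the int32_t
// variant above except that `indices` and the derived loop counters are
// 64-bit; the per-index dequantize-and-accumulate math is unchanged.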
2241 void EmbeddingLookup_int64_t_uint8_t_float__avx2_fma(
2242  const TIndex block_size,
2243  const TIndex output_size,
2244  const TIndex index_size,
2245  const TIndex data_size,
2246  const uint8_t* input,
2247  const int64_t* indices,
2248  const int* lengths,
2249  const float* weights,
2250  const float* scale_bias,
2251  bool normalize_by_lengths,
2252  float* out) {
2253  const int64_t prefdist_T0 = 16;
2254  CAFFE_ENFORCE(scale_bias != nullptr, "scale_bias must not be nullptr");
2255  if (block_size == 128) {
2256  // unrolling 16 times
2257  int64_t dataInd = 0;
2258  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2259  float* op = &out[rangeIndex * block_size];
2260  __m256 vop0 = _mm256_setzero_ps();
2261  __m256 vop8 = _mm256_setzero_ps();
2262  __m256 vop16 = _mm256_setzero_ps();
2263  __m256 vop24 = _mm256_setzero_ps();
2264  __m256 vop32 = _mm256_setzero_ps();
2265  __m256 vop40 = _mm256_setzero_ps();
2266  __m256 vop48 = _mm256_setzero_ps();
2267  __m256 vop56 = _mm256_setzero_ps();
2268  __m256 vop64 = _mm256_setzero_ps();
2269  __m256 vop72 = _mm256_setzero_ps();
2270  __m256 vop80 = _mm256_setzero_ps();
2271  __m256 vop88 = _mm256_setzero_ps();
2272  __m256 vop96 = _mm256_setzero_ps();
2273  __m256 vop104 = _mm256_setzero_ps();
2274  __m256 vop112 = _mm256_setzero_ps();
2275  __m256 vop120 = _mm256_setzero_ps();
2276  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2277  ++dataInd) {
2278  const int64_t idx = indices[dataInd];
2279  CAFFE_ENFORCE(
2280  idx >= 0 && idx < data_size,
2281  "Index ",
2282  dataInd,
2283  " is out of bounds: ",
2284  idx,
2285  ", range 0 to ",
2286  data_size);
2287  float wgt = 1.f;
2288  float bio;
2289  if (weights) {
2290  wgt = weights[dataInd];
2291  }
2292  bio = wgt * scale_bias[2 * idx + 1];
2293  wgt = wgt * scale_bias[2 * idx];
2294  __m256 vbio = _mm256_set1_ps(bio);
2295  __m256 vwgt = _mm256_set1_ps(wgt);
2296  const uint8_t* ip = &input[idx * block_size];
2297  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2298  ? (dataInd + prefdist_T0)
2299  : dataInd;
2300  const int64_t idx_pref_T0 = indices[next_T0];
2301  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2302  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2303  vop0 = _mm256_fmadd_ps(
2304  vwgt,
2305  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2306  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2307  _mm256_add_ps(vop0, vbio));
2308  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2309  vop8 = _mm256_fmadd_ps(
2310  vwgt,
2311  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2312  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2313  _mm256_add_ps(vop8, vbio));
 2314  // skip unnecessary prefetch of (&ip_next_T0[8])
2315  vop16 = _mm256_fmadd_ps(
2316  vwgt,
2317  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2318  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2319  _mm256_add_ps(vop16, vbio));
 2320  // skip unnecessary prefetch of (&ip_next_T0[16])
2321  vop24 = _mm256_fmadd_ps(
2322  vwgt,
2323  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2324  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2325  _mm256_add_ps(vop24, vbio));
 2326  // skip unnecessary prefetch of (&ip_next_T0[24])
2327  vop32 = _mm256_fmadd_ps(
2328  vwgt,
2329  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2330  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2331  _mm256_add_ps(vop32, vbio));
 2332  // skip unnecessary prefetch of (&ip_next_T0[32])
2333  vop40 = _mm256_fmadd_ps(
2334  vwgt,
2335  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2336  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2337  _mm256_add_ps(vop40, vbio));
 2338  // skip unnecessary prefetch of (&ip_next_T0[40])
2339  vop48 = _mm256_fmadd_ps(
2340  vwgt,
2341  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2342  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2343  _mm256_add_ps(vop48, vbio));
 2344  // skip unnecessary prefetch of (&ip_next_T0[48])
2345  vop56 = _mm256_fmadd_ps(
2346  vwgt,
2347  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2348  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2349  _mm256_add_ps(vop56, vbio));
 2350  // skip unnecessary prefetch of (&ip_next_T0[56])
2351  vop64 = _mm256_fmadd_ps(
2352  vwgt,
2353  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2354  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
2355  _mm256_add_ps(vop64, vbio));
2356  _mm_prefetch((&ip_next_T0[64]), _MM_HINT_T0);
2357  vop72 = _mm256_fmadd_ps(
2358  vwgt,
2359  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2360  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
2361  _mm256_add_ps(vop72, vbio));
 2362  // skip unnecessary prefetch of (&ip_next_T0[72])
2363  vop80 = _mm256_fmadd_ps(
2364  vwgt,
2365  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2366  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
2367  _mm256_add_ps(vop80, vbio));
 2368  // skip unnecessary prefetch of (&ip_next_T0[80])
2369  vop88 = _mm256_fmadd_ps(
2370  vwgt,
2371  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2372  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
2373  _mm256_add_ps(vop88, vbio));
 2374  // skip unnecessary prefetch of (&ip_next_T0[88])
2375  vop96 = _mm256_fmadd_ps(
2376  vwgt,
2377  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2378  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
2379  _mm256_add_ps(vop96, vbio));
 2380  // skip unnecessary prefetch of (&ip_next_T0[96])
2381  vop104 = _mm256_fmadd_ps(
2382  vwgt,
2383  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2384  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
2385  _mm256_add_ps(vop104, vbio));
 2386  // skip unnecessary prefetch of (&ip_next_T0[104])
2387  vop112 = _mm256_fmadd_ps(
2388  vwgt,
2389  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2390  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
2391  _mm256_add_ps(vop112, vbio));
 2392  // skip unnecessary prefetch of (&ip_next_T0[112])
2393  vop120 = _mm256_fmadd_ps(
2394  vwgt,
2395  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2396  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
2397  _mm256_add_ps(vop120, vbio));
 2398  // skip unnecessary prefetch of (&ip_next_T0[120])
2399  }
2400  if (normalize_by_lengths == false) {
2401  _mm256_storeu_ps(&op[0], vop0);
2402  _mm256_storeu_ps(&op[8], vop8);
2403  _mm256_storeu_ps(&op[16], vop16);
2404  _mm256_storeu_ps(&op[24], vop24);
2405  _mm256_storeu_ps(&op[32], vop32);
2406  _mm256_storeu_ps(&op[40], vop40);
2407  _mm256_storeu_ps(&op[48], vop48);
2408  _mm256_storeu_ps(&op[56], vop56);
2409  _mm256_storeu_ps(&op[64], vop64);
2410  _mm256_storeu_ps(&op[72], vop72);
2411  _mm256_storeu_ps(&op[80], vop80);
2412  _mm256_storeu_ps(&op[88], vop88);
2413  _mm256_storeu_ps(&op[96], vop96);
2414  _mm256_storeu_ps(&op[104], vop104);
2415  _mm256_storeu_ps(&op[112], vop112);
2416  _mm256_storeu_ps(&op[120], vop120);
2417  } else if (lengths[rangeIndex]) {
2418  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2419  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2420  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2421  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2422  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2423  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2424  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2425  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2426  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2427  _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2428  _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2429  _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2430  _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2431  _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2432  _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2433  _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2434  _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2435  }
2436  }
2437  } else if (block_size == 64) {
2438  // unrolling 8 times
2439  int64_t dataInd = 0;
2440  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2441  float* op = &out[rangeIndex * block_size];
2442  __m256 vop0 = _mm256_setzero_ps();
2443  __m256 vop8 = _mm256_setzero_ps();
2444  __m256 vop16 = _mm256_setzero_ps();
2445  __m256 vop24 = _mm256_setzero_ps();
2446  __m256 vop32 = _mm256_setzero_ps();
2447  __m256 vop40 = _mm256_setzero_ps();
2448  __m256 vop48 = _mm256_setzero_ps();
2449  __m256 vop56 = _mm256_setzero_ps();
2450  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2451  ++dataInd) {
2452  const int64_t idx = indices[dataInd];
2453  CAFFE_ENFORCE(
2454  idx >= 0 && idx < data_size,
2455  "Index ",
2456  dataInd,
2457  " is out of bounds: ",
2458  idx,
2459  ", range 0 to ",
2460  data_size);
2461  float wgt = 1.f;
2462  float bio;
2463  if (weights) {
2464  wgt = weights[dataInd];
2465  }
2466  bio = wgt * scale_bias[2 * idx + 1];
2467  wgt = wgt * scale_bias[2 * idx];
2468  __m256 vbio = _mm256_set1_ps(bio);
2469  __m256 vwgt = _mm256_set1_ps(wgt);
2470  const uint8_t* ip = &input[idx * block_size];
2471  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2472  ? (dataInd + prefdist_T0)
2473  : dataInd;
2474  const int64_t idx_pref_T0 = indices[next_T0];
2475  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2476  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2477  vop0 = _mm256_fmadd_ps(
2478  vwgt,
2479  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2480  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2481  _mm256_add_ps(vop0, vbio));
2482  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2483  vop8 = _mm256_fmadd_ps(
2484  vwgt,
2485  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2486  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2487  _mm256_add_ps(vop8, vbio));
 2488  // skip unnecessary prefetch of (&ip_next_T0[8])
2489  vop16 = _mm256_fmadd_ps(
2490  vwgt,
2491  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2492  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2493  _mm256_add_ps(vop16, vbio));
 2494  // skip unnecessary prefetch of (&ip_next_T0[16])
2495  vop24 = _mm256_fmadd_ps(
2496  vwgt,
2497  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2498  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2499  _mm256_add_ps(vop24, vbio));
 2500  // skip unnecessary prefetch of (&ip_next_T0[24])
2501  vop32 = _mm256_fmadd_ps(
2502  vwgt,
2503  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2504  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
2505  _mm256_add_ps(vop32, vbio));
 2506  // skip unnecessary prefetch of (&ip_next_T0[32])
2507  vop40 = _mm256_fmadd_ps(
2508  vwgt,
2509  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2510  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
2511  _mm256_add_ps(vop40, vbio));
 2512  // skip unnecessary prefetch of (&ip_next_T0[40])
2513  vop48 = _mm256_fmadd_ps(
2514  vwgt,
2515  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2516  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
2517  _mm256_add_ps(vop48, vbio));
 2518  // skip unnecessary prefetch of (&ip_next_T0[48])
2519  vop56 = _mm256_fmadd_ps(
2520  vwgt,
2521  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2522  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
2523  _mm256_add_ps(vop56, vbio));
 2524  // skip unnecessary prefetch of (&ip_next_T0[56])
2525  }
2526  if (normalize_by_lengths == false) {
2527  _mm256_storeu_ps(&op[0], vop0);
2528  _mm256_storeu_ps(&op[8], vop8);
2529  _mm256_storeu_ps(&op[16], vop16);
2530  _mm256_storeu_ps(&op[24], vop24);
2531  _mm256_storeu_ps(&op[32], vop32);
2532  _mm256_storeu_ps(&op[40], vop40);
2533  _mm256_storeu_ps(&op[48], vop48);
2534  _mm256_storeu_ps(&op[56], vop56);
2535  } else if (lengths[rangeIndex]) {
2536  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2537  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2538  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2539  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2540  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2541  _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2542  _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2543  _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2544  _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2545  }
2546  }
2547  } else if (block_size == 32) {
2548  // unrolling 4 times
2549  int64_t dataInd = 0;
2550  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2551  float* op = &out[rangeIndex * block_size];
2552  __m256 vop0 = _mm256_setzero_ps();
2553  __m256 vop8 = _mm256_setzero_ps();
2554  __m256 vop16 = _mm256_setzero_ps();
2555  __m256 vop24 = _mm256_setzero_ps();
2556  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2557  ++dataInd) {
2558  const int64_t idx = indices[dataInd];
2559  CAFFE_ENFORCE(
2560  idx >= 0 && idx < data_size,
2561  "Index ",
2562  dataInd,
2563  " is out of bounds: ",
2564  idx,
2565  ", range 0 to ",
2566  data_size);
2567  float wgt = 1.f;
2568  float bio;
2569  if (weights) {
2570  wgt = weights[dataInd];
2571  }
2572  bio = wgt * scale_bias[2 * idx + 1];
2573  wgt = wgt * scale_bias[2 * idx];
2574  __m256 vbio = _mm256_set1_ps(bio);
2575  __m256 vwgt = _mm256_set1_ps(wgt);
2576  const uint8_t* ip = &input[idx * block_size];
2577  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2578  ? (dataInd + prefdist_T0)
2579  : dataInd;
2580  const int64_t idx_pref_T0 = indices[next_T0];
2581  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2582  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2583  vop0 = _mm256_fmadd_ps(
2584  vwgt,
2585  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2586  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2587  _mm256_add_ps(vop0, vbio));
2588  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2589  vop8 = _mm256_fmadd_ps(
2590  vwgt,
2591  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2592  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2593  _mm256_add_ps(vop8, vbio));
 2594  // skip unnecessary prefetch of (&ip_next_T0[8])
2595  vop16 = _mm256_fmadd_ps(
2596  vwgt,
2597  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2598  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
2599  _mm256_add_ps(vop16, vbio));
 2600  // skip unnecessary prefetch of (&ip_next_T0[16])
2601  vop24 = _mm256_fmadd_ps(
2602  vwgt,
2603  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2604  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
2605  _mm256_add_ps(vop24, vbio));
 2606  // skip unnecessary prefetch of (&ip_next_T0[24])
2607  }
2608  if (normalize_by_lengths == false) {
2609  _mm256_storeu_ps(&op[0], vop0);
2610  _mm256_storeu_ps(&op[8], vop8);
2611  _mm256_storeu_ps(&op[16], vop16);
2612  _mm256_storeu_ps(&op[24], vop24);
2613  } else if (lengths[rangeIndex]) {
2614  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2615  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2616  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2617  _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2618  _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2619  }
2620  }
2621  } else if (block_size == 16) {
2622  // unrolling 2 times
2623  int64_t dataInd = 0;
2624  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2625  float* op = &out[rangeIndex * block_size];
2626  __m256 vop0 = _mm256_setzero_ps();
2627  __m256 vop8 = _mm256_setzero_ps();
2628  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2629  ++dataInd) {
2630  const int64_t idx = indices[dataInd];
2631  CAFFE_ENFORCE(
2632  idx >= 0 && idx < data_size,
2633  "Index ",
2634  dataInd,
2635  " is out of bounds: ",
2636  idx,
2637  ", range 0 to ",
2638  data_size);
2639  float wgt = 1.f;
2640  float bio;
2641  if (weights) {
2642  wgt = weights[dataInd];
2643  }
2644  bio = wgt * scale_bias[2 * idx + 1];
2645  wgt = wgt * scale_bias[2 * idx];
2646  __m256 vbio = _mm256_set1_ps(bio);
2647  __m256 vwgt = _mm256_set1_ps(wgt);
2648  const uint8_t* ip = &input[idx * block_size];
2649  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2650  ? (dataInd + prefdist_T0)
2651  : dataInd;
2652  const int64_t idx_pref_T0 = indices[next_T0];
2653  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2654  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2655  vop0 = _mm256_fmadd_ps(
2656  vwgt,
2657  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2658  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
2659  _mm256_add_ps(vop0, vbio));
2660  _mm_prefetch((&ip_next_T0[0]), _MM_HINT_T0);
2661  vop8 = _mm256_fmadd_ps(
2662  vwgt,
2663  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
2664  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
2665  _mm256_add_ps(vop8, vbio));
 2666  // skip unnecessary prefetch of (&ip_next_T0[8])
2667  }
2668  if (normalize_by_lengths == false) {
2669  _mm256_storeu_ps(&op[0], vop0);
2670  _mm256_storeu_ps(&op[8], vop8);
2671  } else if (lengths[rangeIndex]) {
2672  __m256 vlen_inv = _mm256_set1_ps(1.0f / lengths[rangeIndex]);
2673  _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2674  _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2675  }
2676  }
2677  } else {
2678  // generic code
2679  int64_t dataInd = 0;
2680  for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2681  float* op = &out[rangeIndex * block_size];
2682  TIndex j = 0;
2683  for (; j + 8 <= block_size; j += 8) {
2684  _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2685  }
2686  for (; j < block_size; j++) {
2687  op[j] = 0.0f;
2688  }
2689  for (int64_t start = dataInd; dataInd < start + lengths[rangeIndex];
2690  ++dataInd) {
2691  const int64_t idx = indices[dataInd];
2692  CAFFE_ENFORCE(
2693  idx >= 0 && idx < data_size,
2694  "Index ",
2695  dataInd,
2696  " is out of bounds: ",
2697  idx,
2698  ", range 0 to ",
2699  data_size);
2700  float wgt = 1.f;
2701  float bio;
2702  if (weights) {
2703  wgt = weights[dataInd];
2704  }
2705  assert(scale_bias);
2706  bio = wgt * scale_bias[2 * idx + 1];
2707  wgt = wgt * scale_bias[2 * idx];
2708  __m256 vbio = _mm256_set1_ps(bio);
2709  __m256 vwgt = _mm256_set1_ps(wgt);
2710  const uint8_t* ip = &input[idx * block_size];
2711  const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2712  ? (dataInd + prefdist_T0)
2713  : dataInd;
2714  const int64_t idx_pref_T0 = indices[next_T0];
2715  CAFFE_ENFORCE(idx_pref_T0 >= 0 && idx_pref_T0 < data_size);
2716  const uint8_t* ip_next_T0 = &input[idx_pref_T0 * block_size];
2717  j = 0;
2718  for (; j + 8 <= block_size; j += 8) {
2719  _mm256_storeu_ps(
2720  &op[j],
2721  _mm256_fmadd_ps(
2722  vwgt,
2723  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
2724  reinterpret_cast<const __m128i*>(&ip[j])))),
2725  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
2726  _mm_prefetch((&ip_next_T0[j]), _MM_HINT_T0);
2727  }
2728  for (; j < block_size; j++) {
2729  op[j] += wgt * ((float)ip[j]) + bio;
2730  }
2731  }
2732  if (normalize_by_lengths && lengths[rangeIndex]) {
2733  float len_inv = 1.0f / lengths[rangeIndex];
2734  __m256 vlen_inv = _mm256_set1_ps(len_inv);
2735  j = 0;
2736  for (; j + 8 <= block_size; j += 8) {
2737  _mm256_storeu_ps(
2738  &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2739  }
2740  for (; j < block_size; j++) {
2741  op[j] = len_inv * op[j];
2742  }
2743  }
2744  }
2745  }
2746 }
2747 
2748 } // namespace caffe2