Caffe2 - C++ API
A deep learning, cross-platform ML framework
vec256_int.h
#pragma once

#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/cpu/vec256/vec256_base.h>

#include <algorithm>  // std::min, std::max
#include <cstring>    // std::memcpy

namespace at {
namespace vec256 {
namespace {

#ifdef __AVX2__

struct Vec256i {
protected:
  __m256i values;

  static inline __m256i invert(const __m256i& v) {
    const auto ones = _mm256_set1_epi64x(-1);
    return _mm256_xor_si256(ones, v);
  }
public:
  Vec256i() {}
  Vec256i(__m256i v) : values(v) {}
  operator __m256i() const {
    return values;
  }
};

template <>
struct Vec256<int64_t> : public Vec256i {
  static constexpr int size() {
    return 4;
  }
  using Vec256i::Vec256i;
  Vec256() {}
  Vec256(int64_t v) { values = _mm256_set1_epi64x(v); }
  Vec256(int64_t val1, int64_t val2, int64_t val3, int64_t val4) {
    values = _mm256_setr_epi64x(val1, val2, val3, val4);
  }
  template <int64_t mask>
  static Vec256<int64_t> blend(Vec256<int64_t> a, Vec256<int64_t> b) {
    __at_align32__ int64_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi64(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi64(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi64(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi64(b.values, 3);
    return loadu(tmp_values);
  }
  static Vec256<int64_t> blendv(const Vec256<int64_t>& a, const Vec256<int64_t>& b,
                                const Vec256<int64_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  static Vec256<int64_t> arange(int64_t base = 0, int64_t step = 1) {
    return Vec256<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
  }
  static Vec256<int64_t>
  set(Vec256<int64_t> a, Vec256<int64_t> b, int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
  static Vec256<int64_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vec256<int64_t> loadu(const void* ptr, int64_t count) {
    __at_align32__ int64_t tmp_values[size()];
    std::memcpy(tmp_values, ptr, count * sizeof(int64_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align32__ int64_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int64_t));
    }
  }
  const int64_t& operator[](int idx) const = delete;
  int64_t& operator[](int idx) = delete;
  Vec256<int64_t> abs() const {
    auto zero = _mm256_set1_epi64x(0);
    auto is_larger = _mm256_cmpgt_epi64(zero, values);
    auto inverse = _mm256_xor_si256(values, is_larger);
    return _mm256_sub_epi64(inverse, is_larger);
  }
  Vec256<int64_t> operator==(const Vec256<int64_t>& other) const {
    return _mm256_cmpeq_epi64(values, other.values);
  }
  Vec256<int64_t> operator!=(const Vec256<int64_t>& other) const {
    return invert(_mm256_cmpeq_epi64(values, other.values));
  }
  Vec256<int64_t> operator<(const Vec256<int64_t>& other) const {
    return _mm256_cmpgt_epi64(other.values, values);
  }
  Vec256<int64_t> operator<=(const Vec256<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(values, other.values));
  }
  Vec256<int64_t> operator>(const Vec256<int64_t>& other) const {
    return _mm256_cmpgt_epi64(values, other.values);
  }
  Vec256<int64_t> operator>=(const Vec256<int64_t>& other) const {
    return invert(_mm256_cmpgt_epi64(other.values, values));
  }
};
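
// Illustrative sketch added for this listing (not part of the upstream header):
// blend<mask> takes lane i from b when bit i of the compile-time mask is set and
// from a otherwise, and set(a, b, count) takes the first `count` lanes from b.
// With hypothetical values:
//   Vec256<int64_t> a(0, 1, 2, 3);
//   Vec256<int64_t> b(10, 11, 12, 13);
//   auto c = Vec256<int64_t>::blend<0x5>(a, b);  // lanes 0 and 2 from b: {10, 1, 12, 3}
//   auto d = Vec256<int64_t>::set(a, b, 2);      // first two lanes from b: {10, 11, 2, 3}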

template <>
struct Vec256<int32_t> : public Vec256i {
  static constexpr int size() {
    return 8;
  }
  using Vec256i::Vec256i;
  Vec256() {}
  Vec256(int32_t v) { values = _mm256_set1_epi32(v); }
  Vec256(int32_t val1, int32_t val2, int32_t val3, int32_t val4,
         int32_t val5, int32_t val6, int32_t val7, int32_t val8) {
    values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8);
  }
  template <int64_t mask>
  static Vec256<int32_t> blend(Vec256<int32_t> a, Vec256<int32_t> b) {
    return _mm256_blend_epi32(a, b, mask);
  }
  static Vec256<int32_t> blendv(const Vec256<int32_t>& a, const Vec256<int32_t>& b,
                                const Vec256<int32_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  static Vec256<int32_t> arange(int32_t base = 0, int32_t step = 1) {
    return Vec256<int32_t>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
  }
  static Vec256<int32_t>
  set(Vec256<int32_t> a, Vec256<int32_t> b, int32_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
    }
    return b;
  }
  static Vec256<int32_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vec256<int32_t> loadu(const void* ptr, int32_t count) {
    __at_align32__ int32_t tmp_values[size()];
    std::memcpy(tmp_values, ptr, count * sizeof(int32_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align32__ int32_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int32_t));
    }
  }
  const int32_t& operator[](int idx) const = delete;
  int32_t& operator[](int idx) = delete;
  Vec256<int32_t> abs() const {
    return _mm256_abs_epi32(values);
  }
  Vec256<int32_t> operator==(const Vec256<int32_t>& other) const {
    return _mm256_cmpeq_epi32(values, other.values);
  }
  Vec256<int32_t> operator!=(const Vec256<int32_t>& other) const {
    return invert(_mm256_cmpeq_epi32(values, other.values));
  }
  Vec256<int32_t> operator<(const Vec256<int32_t>& other) const {
    return _mm256_cmpgt_epi32(other.values, values);
  }
  Vec256<int32_t> operator<=(const Vec256<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(values, other.values));
  }
  Vec256<int32_t> operator>(const Vec256<int32_t>& other) const {
    return _mm256_cmpgt_epi32(values, other.values);
  }
  Vec256<int32_t> operator>=(const Vec256<int32_t>& other) const {
    return invert(_mm256_cmpgt_epi32(other.values, values));
  }
};
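
// Illustrative sketch added for this listing (not part of the upstream header):
// the comparison operators above return all-ones or all-zeros lanes, which can be
// fed directly into blendv as a selection mask. A hypothetical clamp to 5:
//   Vec256<int32_t> x(1, 9, 2, 8, 3, 7, 4, 6);
//   Vec256<int32_t> y(5);                                // broadcast 5 to all lanes
//   auto mask = x < y;                                   // all-ones lanes where x[i] < 5
//   auto clipped = Vec256<int32_t>::blendv(y, x, mask);  // {1, 5, 2, 5, 3, 5, 4, 5}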

template <>
void convert(const int32_t *src, float *dst, int64_t n) {
  int64_t i;
  // int32_t and float have same size
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vec256<int32_t>::size()); i += Vec256<int32_t>::size()) {
    auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
    auto output_vec = _mm256_cvtepi32_ps(input_vec);
    _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<float>(src[i]);
  }
}

template <>
void convert(const int32_t *src, double *dst, int64_t n) {
  int64_t i;
  // int32_t has half the size of double
#ifndef _MSC_VER
# pragma unroll
#endif
  for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
    auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
    auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
    _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
  }
#ifndef _MSC_VER
# pragma unroll
#endif
  for (; i < n; i++) {
    dst[i] = static_cast<double>(src[i]);
  }
}
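
// Illustrative usage sketch added for this listing (not part of the upstream header):
// convert() widens int32_t data in vectorized chunks and finishes the tail in the
// scalar loop. With a hypothetical 10-element buffer:
//   int32_t src[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
//   float dst[10];
//   convert(src, dst, 10);  // first 8 elements in one AVX2 iteration, last 2 scalar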

template <>
struct Vec256<int16_t> : public Vec256i {
  static constexpr int size() {
    return 16;
  }
  using Vec256i::Vec256i;
  Vec256() {}
  Vec256(int16_t v) { values = _mm256_set1_epi16(v); }
  Vec256(int16_t val1, int16_t val2, int16_t val3, int16_t val4,
         int16_t val5, int16_t val6, int16_t val7, int16_t val8,
         int16_t val9, int16_t val10, int16_t val11, int16_t val12,
         int16_t val13, int16_t val14, int16_t val15, int16_t val16) {
    values = _mm256_setr_epi16(val1, val2, val3, val4, val5, val6, val7, val8,
                               val9, val10, val11, val12, val13, val14, val15, val16);
  }
  template <int64_t mask>
  static Vec256<int16_t> blend(Vec256<int16_t> a, Vec256<int16_t> b) {
    __at_align32__ int16_t tmp_values[size()];
    a.store(tmp_values);
    if (mask & 0x01)
      tmp_values[0] = _mm256_extract_epi16(b.values, 0);
    if (mask & 0x02)
      tmp_values[1] = _mm256_extract_epi16(b.values, 1);
    if (mask & 0x04)
      tmp_values[2] = _mm256_extract_epi16(b.values, 2);
    if (mask & 0x08)
      tmp_values[3] = _mm256_extract_epi16(b.values, 3);
    if (mask & 0x10)
      tmp_values[4] = _mm256_extract_epi16(b.values, 4);
    if (mask & 0x20)
      tmp_values[5] = _mm256_extract_epi16(b.values, 5);
    if (mask & 0x40)
      tmp_values[6] = _mm256_extract_epi16(b.values, 6);
    if (mask & 0x80)
      tmp_values[7] = _mm256_extract_epi16(b.values, 7);
    if (mask & 0x100)
      tmp_values[8] = _mm256_extract_epi16(b.values, 8);
    if (mask & 0x200)
      tmp_values[9] = _mm256_extract_epi16(b.values, 9);
    if (mask & 0x400)
      tmp_values[10] = _mm256_extract_epi16(b.values, 10);
    if (mask & 0x800)
      tmp_values[11] = _mm256_extract_epi16(b.values, 11);
    if (mask & 0x1000)
      tmp_values[12] = _mm256_extract_epi16(b.values, 12);
    if (mask & 0x2000)
      tmp_values[13] = _mm256_extract_epi16(b.values, 13);
    if (mask & 0x4000)
      tmp_values[14] = _mm256_extract_epi16(b.values, 14);
    if (mask & 0x8000)
      tmp_values[15] = _mm256_extract_epi16(b.values, 15);
    return loadu(tmp_values);
  }
  static Vec256<int16_t> blendv(const Vec256<int16_t>& a, const Vec256<int16_t>& b,
                                const Vec256<int16_t>& mask) {
    return _mm256_blendv_epi8(a.values, b.values, mask.values);
  }
  static Vec256<int16_t> arange(int16_t base = 0, int16_t step = 1) {
    return Vec256<int16_t>(
      base, base + step, base + 2 * step, base + 3 * step,
      base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
      base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
      base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
  }
  static Vec256<int16_t>
  set(Vec256<int16_t> a, Vec256<int16_t> b, int16_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
      case 4:
        return blend<15>(a, b);
      case 5:
        return blend<31>(a, b);
      case 6:
        return blend<63>(a, b);
      case 7:
        return blend<127>(a, b);
      case 8:
        return blend<255>(a, b);
      case 9:
        return blend<511>(a, b);
      case 10:
        return blend<1023>(a, b);
      case 11:
        return blend<2047>(a, b);
      case 12:
        return blend<4095>(a, b);
      case 13:
        return blend<8191>(a, b);
      case 14:
        return blend<16383>(a, b);
      case 15:
        return blend<32767>(a, b);
    }
    return b;
  }
  static Vec256<int16_t> loadu(const void* ptr) {
    return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
  }
  static Vec256<int16_t> loadu(const void* ptr, int16_t count) {
    __at_align32__ int16_t tmp_values[size()];
    std::memcpy(tmp_values, ptr, count * sizeof(int16_t));
    return loadu(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
    } else if (count > 0) {
      __at_align32__ int16_t tmp_values[size()];
      _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(int16_t));
    }
  }
  const int16_t& operator[](int idx) const = delete;
  int16_t& operator[](int idx) = delete;
  Vec256<int16_t> abs() const {
    return _mm256_abs_epi16(values);
  }
  Vec256<int16_t> operator==(const Vec256<int16_t>& other) const {
    return _mm256_cmpeq_epi16(values, other.values);
  }
  Vec256<int16_t> operator!=(const Vec256<int16_t>& other) const {
    return invert(_mm256_cmpeq_epi16(values, other.values));
  }
  Vec256<int16_t> operator<(const Vec256<int16_t>& other) const {
    return _mm256_cmpgt_epi16(other.values, values);
  }
  Vec256<int16_t> operator<=(const Vec256<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(values, other.values));
  }
  Vec256<int16_t> operator>(const Vec256<int16_t>& other) const {
    return _mm256_cmpgt_epi16(values, other.values);
  }
  Vec256<int16_t> operator>=(const Vec256<int16_t>& other) const {
    return invert(_mm256_cmpgt_epi16(other.values, values));
  }
};

template <>
Vec256<int64_t> inline operator+(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
  return _mm256_add_epi64(a, b);
}

template <>
Vec256<int32_t> inline operator+(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
  return _mm256_add_epi32(a, b);
}

template <>
Vec256<int16_t> inline operator+(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
  return _mm256_add_epi16(a, b);
}

template <>
Vec256<int64_t> inline operator-(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
  return _mm256_sub_epi64(a, b);
}

template <>
Vec256<int32_t> inline operator-(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
  return _mm256_sub_epi32(a, b);
}

template <>
Vec256<int16_t> inline operator-(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
  return _mm256_sub_epi16(a, b);
}

// Emulate operations with no native 64-bit support in AVX,
// by extracting each element, performing the operation pointwise,
// then combining the results into a vector.
template <typename op_t>
Vec256<int64_t> inline emulate(const Vec256<int64_t>& a, const Vec256<int64_t>& b, const op_t& op) {
  int64_t a0 = _mm256_extract_epi64(a, 0);
  int64_t a1 = _mm256_extract_epi64(a, 1);
  int64_t a2 = _mm256_extract_epi64(a, 2);
  int64_t a3 = _mm256_extract_epi64(a, 3);

  int64_t b0 = _mm256_extract_epi64(b, 0);
  int64_t b1 = _mm256_extract_epi64(b, 1);
  int64_t b2 = _mm256_extract_epi64(b, 2);
  int64_t b3 = _mm256_extract_epi64(b, 3);

  int64_t c0 = op(a0, b0);
  int64_t c1 = op(a1, b1);
  int64_t c2 = op(a2, b2);
  int64_t c3 = op(a3, b3);

  return _mm256_set_epi64x(c3, c2, c1, c0);
}

// AVX2 has no intrinsic for int64_t multiply so it needs to be emulated.
// This could be implemented more efficiently using epi32 instructions.
// This is also technically AVX compatible, but then we'll need AVX
// code for add as well.
template <>
Vec256<int64_t> inline operator*(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point){return a_point * b_point;});
}

template <>
Vec256<int32_t> inline operator*(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
  return _mm256_mullo_epi32(a, b);
}

template <>
Vec256<int16_t> inline operator*(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
  return _mm256_mullo_epi16(a, b);
}

template <>
Vec256<int64_t> inline minimum(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::min(a_point, b_point);});
}

template <>
Vec256<int32_t> inline minimum(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
  return _mm256_min_epi32(a, b);
}

template <>
Vec256<int16_t> inline minimum(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
  return _mm256_min_epi16(a, b);
}

template <>
Vec256<int64_t> inline maximum(const Vec256<int64_t>& a, const Vec256<int64_t>& b) {
  return emulate(a, b, [](int64_t a_point, int64_t b_point) {return std::max(a_point, b_point);});
}

template <>
Vec256<int32_t> inline maximum(const Vec256<int32_t>& a, const Vec256<int32_t>& b) {
  return _mm256_max_epi32(a, b);
}

template <>
Vec256<int16_t> inline maximum(const Vec256<int16_t>& a, const Vec256<int16_t>& b) {
  return _mm256_max_epi16(a, b);
}
template <typename T>
Vec256<T> inline intdiv_256(const Vec256<T>& a, const Vec256<T>& b) {
  T values_a[Vec256<T>::size()];
  T values_b[Vec256<T>::size()];
  a.store(values_a);
  b.store(values_b);
  for (int i = 0; i != Vec256<T>::size(); i++) {
    values_a[i] /= values_b[i];
  }
  return Vec256<T>::loadu(values_a);
}

#define DEFINE_INTEGER_BINARY_OP(op, func) \
template <> \
Vec256<int64_t> inline operator op(const Vec256<int64_t>& a, const Vec256<int64_t>& b) { \
  return func(a, b); \
} \
template <> \
Vec256<int32_t> inline operator op(const Vec256<int32_t>& a, const Vec256<int32_t>& b) { \
  return func(a, b); \
} \
template <> \
Vec256<int16_t> inline operator op(const Vec256<int16_t>& a, const Vec256<int16_t>& b) { \
  return func(a, b); \
}

DEFINE_INTEGER_BINARY_OP(/, intdiv_256)
DEFINE_INTEGER_BINARY_OP(&, _mm256_and_si256)
DEFINE_INTEGER_BINARY_OP(|, _mm256_or_si256)
DEFINE_INTEGER_BINARY_OP(^, _mm256_xor_si256)

#undef DEFINE_INTEGER_BINARY_OP
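
// Illustrative usage sketch added for this listing (not part of the upstream header):
// the macro above specializes each listed operator for all three integer element
// types. With hypothetical broadcast values:
//   Vec256<int32_t> x(6), y(3);
//   auto q = x / y;  // element-wise integer division via intdiv_256
//   auto m = x & y;  // lane-wise bitwise AND via _mm256_and_si256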

#endif

}}}
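
A minimal end-to-end sketch of how these specializations are typically consumed (illustrative only: the umbrella include path, the add_one function, and the buffer contents are assumptions, and the AVX2 code path above must be enabled for the intrinsics to be used):

  #include <ATen/cpu/vec256/vec256.h>

  // Adds 1 to every element of an int32_t buffer: full 8-lane chunks use the
  // AVX2 loadu/store plus operator+ above; the tail uses the count-limited
  // loadu/store overloads.
  void add_one(int32_t* data, int64_t n) {
    using Vec = at::vec256::Vec256<int32_t>;
    const Vec ones(1);
    int64_t i = 0;
    for (; i + Vec::size() <= n; i += Vec::size()) {
      auto v = Vec::loadu(data + i) + ones;
      v.store(data + i);
    }
    if (i < n) {
      auto v = Vec::loadu(data + i, n - i) + ones;  // partial load of the remaining lanes
      v.store(data + i, n - i);                     // partial store of the remaining lanes
    }
  }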