3 #include <ATen/cpu/vec256/intrinsics.h> 4 #include <ATen/cpu/vec256/vec256_base.h> 16 static inline __m256i invert(
const __m256i& v) {
17 const auto ones = _mm256_set1_epi64x(-1);
18 return _mm256_xor_si256(ones, v);
// Implicit conversions between the wrapper and the raw AVX2 register.
// NOTE(review): the enclosing `struct Vec256i` header and its `values`
// member declaration are not visible in this chunk — assumed to be
// declared immediately above; confirm against the full file.
22 Vec256i(__m256i v) : values(v) {}
23 operator __m256i()
const {
// Specialization of Vec256 for packed int64_t lanes in one 256-bit register.
29 struct Vec256<int64_t> :
public Vec256i {
// Number of int64_t elements per vector.
// NOTE(review): the body of size() (presumably `return 4;` — a 256-bit
// register holds four 64-bit lanes) appears elided from this view.
30 static constexpr
int size() {
33 using Vec256i::Vec256i;
// Broadcast a single scalar into every lane.
35 Vec256(int64_t v) { values = _mm256_set1_epi64x(v); }
// Set all lanes individually; val1 is the lowest lane (setr = reversed order).
36 Vec256(int64_t val1, int64_t val2, int64_t val3, int64_t val4) {
37 values = _mm256_setr_epi64x(val1, val2, val3, val4);
// Compile-time blend: lane i is taken from `b` when bit i of `mask` is
// set, otherwise from `a`.
// NOTE(review): the `a.store(tmp_values)` prologue and the per-lane
// `if (mask & ...)` guards appear elided from this view — as shown, the
// unconditional extracts would always copy every lane of `b`. Confirm
// against the full file before relying on this text.
39 template <
int64_t mask>
40 static Vec256<int64_t> blend(Vec256<int64_t> a, Vec256<int64_t> b) {
41 __at_align32__ int64_t tmp_values[size()];
44 tmp_values[0] = _mm256_extract_epi64(b.values, 0);
46 tmp_values[1] = _mm256_extract_epi64(b.values, 1);
48 tmp_values[2] = _mm256_extract_epi64(b.values, 2);
50 tmp_values[3] = _mm256_extract_epi64(b.values, 3);
51 return loadu(tmp_values);
53 static Vec256<int64_t> blendv(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b,
54 const Vec256<int64_t>& mask) {
55 return _mm256_blendv_epi8(a.values, b.values, mask.values);
57 static Vec256<int64_t> arange(int64_t base = 0, int64_t step = 1) {
58 return Vec256<int64_t>(base, base + step, base + 2 * step, base + 3 * step);
// Return a vector whose first `count` lanes come from `b` and the rest
// from `a`.
// NOTE(review): the `switch (count)` / `case` scaffolding and the
// `count == size()` path (presumably `return b;`) appear elided from this
// view — confirm against the full file.
60 static Vec256<int64_t>
61 set(Vec256<int64_t> a, Vec256<int64_t> b, int64_t count = size()) {
66 return blend<1>(a, b);
68 return blend<3>(a, b);
70 return blend<7>(a, b);
74 static Vec256<int64_t> loadu(
const void* ptr) {
75 return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
77 static Vec256<int64_t> loadu(
const void* ptr, int64_t count) {
78 __at_align32__ int64_t tmp_values[size()];
79 std::memcpy(tmp_values, ptr, count *
sizeof(int64_t));
80 return loadu(tmp_values);
82 void store(
void* ptr,
int count = size())
const {
83 if (count == size()) {
84 _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
85 }
else if (count > 0) {
86 __at_align32__ int64_t tmp_values[size()];
87 _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
88 std::memcpy(ptr, tmp_values, count *
sizeof(int64_t));
91 const int64_t& operator[](
int idx)
const =
delete;
92 int64_t& operator[](
int idx) =
delete;
93 Vec256<int64_t> abs()
const {
94 auto zero = _mm256_set1_epi64x(0);
95 auto is_larger = _mm256_cmpgt_epi64(zero, values);
96 auto inverse = _mm256_xor_si256(values, is_larger);
97 return _mm256_sub_epi64(inverse, is_larger);
99 Vec256<int64_t> operator==(
const Vec256<int64_t>& other)
const {
100 return _mm256_cmpeq_epi64(values, other.values);
102 Vec256<int64_t> operator!=(
const Vec256<int64_t>& other)
const {
103 return invert(_mm256_cmpeq_epi64(values, other.values));
105 Vec256<int64_t> operator<(const Vec256<int64_t>& other)
const {
106 return _mm256_cmpgt_epi64(other.values, values);
108 Vec256<int64_t> operator<=(const Vec256<int64_t>& other)
const {
109 return invert(_mm256_cmpgt_epi64(values, other.values));
111 Vec256<int64_t> operator>(
const Vec256<int64_t>& other)
const {
112 return _mm256_cmpgt_epi64(values, other.values);
114 Vec256<int64_t> operator>=(
const Vec256<int64_t>& other)
const {
115 return invert(_mm256_cmpgt_epi64(other.values, values));
// Specialization of Vec256 for packed int32_t lanes in one 256-bit register.
120 struct Vec256<int32_t> :
public Vec256i {
// Number of int32_t elements per vector.
// NOTE(review): the body of size() (presumably `return 8;` — a 256-bit
// register holds eight 32-bit lanes) appears elided from this view.
121 static constexpr
int size() {
124 using Vec256i::Vec256i;
// Broadcast a single scalar into every lane.
126 Vec256(int32_t v) { values = _mm256_set1_epi32(v); }
// Set all lanes individually; val1 is the lowest lane (setr = reversed order).
127 Vec256(int32_t val1, int32_t val2, int32_t val3, int32_t val4,
128 int32_t val5, int32_t val6, int32_t val7, int32_t val8) {
129 values = _mm256_setr_epi32(val1, val2, val3, val4, val5, val6, val7, val8);
131 template <
int64_t mask>
132 static Vec256<int32_t> blend(Vec256<int32_t> a, Vec256<int32_t> b) {
133 return _mm256_blend_epi32(a, b, mask);
135 static Vec256<int32_t> blendv(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b,
136 const Vec256<int32_t>& mask) {
137 return _mm256_blendv_epi8(a.values, b.values, mask.values);
139 static Vec256<int32_t> arange(int32_t base = 0, int32_t step = 1) {
140 return Vec256<int32_t>(
141 base, base + step, base + 2 * step, base + 3 * step,
142 base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
// Return a vector whose first `count` lanes come from `b` and the rest
// from `a`.
// NOTE(review): the `switch (count)` / `case` scaffolding and the
// `count == size()` path (presumably `return b;`) appear elided from this
// view — confirm against the full file.
144 static Vec256<int32_t>
145 set(Vec256<int32_t> a, Vec256<int32_t> b, int32_t count = size()) {
150 return blend<1>(a, b);
152 return blend<3>(a, b);
154 return blend<7>(a, b);
156 return blend<15>(a, b);
158 return blend<31>(a, b);
160 return blend<63>(a, b);
162 return blend<127>(a, b);
166 static Vec256<int32_t> loadu(
const void* ptr) {
167 return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
169 static Vec256<int32_t> loadu(
const void* ptr, int32_t count) {
170 __at_align32__ int32_t tmp_values[size()];
171 std::memcpy(tmp_values, ptr, count *
sizeof(int32_t));
172 return loadu(tmp_values);
174 void store(
void* ptr,
int count = size())
const {
175 if (count == size()) {
176 _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
177 }
else if (count > 0) {
178 __at_align32__ int32_t tmp_values[size()];
179 _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
180 std::memcpy(ptr, tmp_values, count *
sizeof(int32_t));
183 const int32_t& operator[](
int idx)
const =
delete;
184 int32_t& operator[](
int idx) =
delete;
185 Vec256<int32_t> abs()
const {
186 return _mm256_abs_epi32(values);
188 Vec256<int32_t> operator==(
const Vec256<int32_t>& other)
const {
189 return _mm256_cmpeq_epi32(values, other.values);
191 Vec256<int32_t> operator!=(
const Vec256<int32_t>& other)
const {
192 return invert(_mm256_cmpeq_epi32(values, other.values));
194 Vec256<int32_t> operator<(const Vec256<int32_t>& other)
const {
195 return _mm256_cmpgt_epi32(other.values, values);
197 Vec256<int32_t> operator<=(const Vec256<int32_t>& other)
const {
198 return invert(_mm256_cmpgt_epi32(values, other.values));
200 Vec256<int32_t> operator>(
const Vec256<int32_t>& other)
const {
201 return _mm256_cmpgt_epi32(values, other.values);
203 Vec256<int32_t> operator>=(
const Vec256<int32_t>& other)
const {
204 return invert(_mm256_cmpgt_epi32(other.values, values));
// Vectorized conversion of `n` contiguous int32 values to float.
// NOTE(review): the declaration of the induction variable `i` and the
// scalar tail-loop header (presumably `for (; i < n; i++)`) appear elided
// from this view — confirm against the full file.
209 void convert(
const int32_t *src,
float *dst, int64_t n) {
// Main loop: convert one full vector of int32 lanes per iteration.
215 for (i = 0; i <= (n - Vec256<int32_t>::size()); i += Vec256<int32_t>::size()) {
216 auto input_vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src + i));
217 auto output_vec = _mm256_cvtepi32_ps(input_vec);
218 _mm256_storeu_ps(reinterpret_cast<float*>(dst + i), output_vec);
// Scalar tail for the leftover elements that do not fill a vector.
224 dst[i] =
static_cast<float>(src[i]);
// Vectorized conversion of `n` contiguous int32 values to double.
// Steps by Vec256<double>::size() because one 256-bit double vector holds
// only four lanes, so each iteration consumes a 128-bit load of four ints.
// NOTE(review): the declaration of the induction variable `i` and the
// scalar tail-loop header appear elided from this view — confirm against
// the full file.
229 void convert(
const int32_t *src,
double *dst, int64_t n) {
235 for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
236 auto input_128_vec = _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i));
237 auto output_vec = _mm256_cvtepi32_pd(input_128_vec);
238 _mm256_storeu_pd(reinterpret_cast<double*>(dst + i), output_vec);
// Scalar tail for the leftover elements that do not fill a vector.
244 dst[i] =
static_cast<double>(src[i]);
// Specialization of Vec256 for packed int16_t lanes in one 256-bit register.
249 struct Vec256<int16_t> :
public Vec256i {
// Number of int16_t elements per vector.
// NOTE(review): the body of size() (presumably `return 16;` — a 256-bit
// register holds sixteen 16-bit lanes) appears elided from this view.
250 static constexpr
int size() {
253 using Vec256i::Vec256i;
// Broadcast a single scalar into every lane.
255 Vec256(int16_t v) { values = _mm256_set1_epi16(v); }
// Set all sixteen lanes individually; val1 is the lowest lane.
256 Vec256(int16_t val1, int16_t val2, int16_t val3, int16_t val4,
257 int16_t val5, int16_t val6, int16_t val7, int16_t val8,
258 int16_t val9, int16_t val10, int16_t val11, int16_t val12,
259 int16_t val13, int16_t val14, int16_t val15, int16_t val16) {
260 values = _mm256_setr_epi16(val1, val2, val3, val4, val5, val6, val7, val8,
261 val9, val10, val11, val12, val13, val14, val15, val16);
// Compile-time blend: lane i is taken from `b` when bit i of `mask` is
// set, otherwise from `a`. vpblendw's immediate cannot address all 16
// lanes across both 128-bit halves, hence the scalar scratch-buffer path.
// NOTE(review): the `a.store(tmp_values)` prologue and the per-lane
// `if (mask & ...)` guards appear elided from this view — as shown, the
// unconditional extracts would always copy every lane of `b`. Confirm
// against the full file before relying on this text.
263 template <
int64_t mask>
264 static Vec256<int16_t> blend(Vec256<int16_t> a, Vec256<int16_t> b) {
265 __at_align32__ int16_t tmp_values[size()];
268 tmp_values[0] = _mm256_extract_epi16(b.values, 0);
270 tmp_values[1] = _mm256_extract_epi16(b.values, 1);
272 tmp_values[2] = _mm256_extract_epi16(b.values, 2);
274 tmp_values[3] = _mm256_extract_epi16(b.values, 3);
276 tmp_values[4] = _mm256_extract_epi16(b.values, 4);
278 tmp_values[5] = _mm256_extract_epi16(b.values, 5);
280 tmp_values[6] = _mm256_extract_epi16(b.values, 6);
282 tmp_values[7] = _mm256_extract_epi16(b.values, 7);
284 tmp_values[8] = _mm256_extract_epi16(b.values, 8);
286 tmp_values[9] = _mm256_extract_epi16(b.values, 9);
288 tmp_values[10] = _mm256_extract_epi16(b.values, 10);
290 tmp_values[11] = _mm256_extract_epi16(b.values, 11);
292 tmp_values[12] = _mm256_extract_epi16(b.values, 12);
294 tmp_values[13] = _mm256_extract_epi16(b.values, 13);
296 tmp_values[14] = _mm256_extract_epi16(b.values, 14);
298 tmp_values[15] = _mm256_extract_epi16(b.values, 15);
299 return loadu(tmp_values);
301 static Vec256<int16_t> blendv(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b,
302 const Vec256<int16_t>& mask) {
303 return _mm256_blendv_epi8(a.values, b.values, mask.values);
305 static Vec256<int16_t> arange(int16_t base = 0, int16_t step = 1) {
306 return Vec256<int16_t>(
307 base, base + step, base + 2 * step, base + 3 * step,
308 base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step,
309 base + 8 * step, base + 9 * step, base + 10 * step, base + 11 * step,
310 base + 12 * step, base + 13 * step, base + 14 * step, base + 15 * step);
// Return a vector whose first `count` lanes come from `b` and the rest
// from `a`.
// NOTE(review): the `switch (count)` / `case` scaffolding and the
// `count == size()` path (presumably `return b;`) appear elided from this
// view — confirm against the full file.
312 static Vec256<int16_t>
313 set(Vec256<int16_t> a, Vec256<int16_t> b, int16_t count = size()) {
318 return blend<1>(a, b);
320 return blend<3>(a, b);
322 return blend<7>(a, b);
324 return blend<15>(a, b);
326 return blend<31>(a, b);
328 return blend<63>(a, b);
330 return blend<127>(a, b);
332 return blend<255>(a, b);
334 return blend<511>(a, b);
336 return blend<1023>(a, b);
338 return blend<2047>(a, b);
340 return blend<4095>(a, b);
342 return blend<8191>(a, b);
344 return blend<16383>(a, b);
346 return blend<32767>(a, b);
350 static Vec256<int16_t> loadu(
const void* ptr) {
351 return _mm256_loadu_si256(reinterpret_cast<const __m256i*>(ptr));
353 static Vec256<int16_t> loadu(
const void* ptr, int16_t count) {
354 __at_align32__ int16_t tmp_values[size()];
355 std::memcpy(tmp_values, ptr, count *
sizeof(int16_t));
356 return loadu(tmp_values);
358 void store(
void* ptr,
int count = size())
const {
359 if (count == size()) {
360 _mm256_storeu_si256(reinterpret_cast<__m256i*>(ptr), values);
361 }
else if (count > 0) {
362 __at_align32__ int16_t tmp_values[size()];
363 _mm256_storeu_si256(reinterpret_cast<__m256i*>(tmp_values), values);
364 std::memcpy(ptr, tmp_values, count *
sizeof(int16_t));
367 const int16_t& operator[](
int idx)
const =
delete;
368 int16_t& operator[](
int idx) =
delete;
369 Vec256<int16_t> abs()
const {
370 return _mm256_abs_epi16(values);
372 Vec256<int16_t> operator==(
const Vec256<int16_t>& other)
const {
373 return _mm256_cmpeq_epi16(values, other.values);
375 Vec256<int16_t> operator!=(
const Vec256<int16_t>& other)
const {
376 return invert(_mm256_cmpeq_epi16(values, other.values));
378 Vec256<int16_t> operator<(const Vec256<int16_t>& other)
const {
379 return _mm256_cmpgt_epi16(other.values, values);
381 Vec256<int16_t> operator<=(const Vec256<int16_t>& other)
const {
382 return invert(_mm256_cmpgt_epi16(values, other.values));
384 Vec256<int16_t> operator>(
const Vec256<int16_t>& other)
const {
385 return _mm256_cmpgt_epi16(values, other.values);
387 Vec256<int16_t> operator>=(
const Vec256<int16_t>& other)
const {
388 return invert(_mm256_cmpgt_epi16(other.values, values));
393 Vec256<int64_t>
inline operator+(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b) {
394 return _mm256_add_epi64(a, b);
398 Vec256<int32_t>
inline operator+(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b) {
399 return _mm256_add_epi32(a, b);
403 Vec256<int16_t>
inline operator+(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b) {
404 return _mm256_add_epi16(a, b);
408 Vec256<int64_t>
inline operator-(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b) {
409 return _mm256_sub_epi64(a, b);
413 Vec256<int32_t>
inline operator-(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b) {
414 return _mm256_sub_epi32(a, b);
418 Vec256<int16_t>
inline operator-(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b) {
419 return _mm256_sub_epi16(a, b);
425 template <
typename op_t>
426 Vec256<int64_t>
inline emulate(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b,
const op_t& op) {
427 int64_t a0 = _mm256_extract_epi64(a, 0);
428 int64_t a1 = _mm256_extract_epi64(a, 1);
429 int64_t a2 = _mm256_extract_epi64(a, 2);
430 int64_t a3 = _mm256_extract_epi64(a, 3);
432 int64_t b0 = _mm256_extract_epi64(b, 0);
433 int64_t b1 = _mm256_extract_epi64(b, 1);
434 int64_t b2 = _mm256_extract_epi64(b, 2);
435 int64_t b3 = _mm256_extract_epi64(b, 3);
437 int64_t c0 = op(a0, b0);
438 int64_t c1 = op(a1, b1);
439 int64_t c2 = op(a2, b2);
440 int64_t c3 = op(a3, b3);
442 return _mm256_set_epi64x(c3, c2, c1, c0);
450 Vec256<int64_t>
inline operator*(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b) {
451 return emulate(a, b, [](int64_t a_point, int64_t b_point){
return a_point * b_point;});
455 Vec256<int32_t>
inline operator*(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b) {
456 return _mm256_mullo_epi32(a, b);
460 Vec256<int16_t>
inline operator*(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b) {
461 return _mm256_mullo_epi16(a, b);
465 Vec256<int64_t>
inline minimum(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b) {
466 return emulate(a, b, [](int64_t a_point, int64_t b_point) {
return std::min(a_point, b_point);});
470 Vec256<int32_t>
inline minimum(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b) {
471 return _mm256_min_epi32(a, b);
475 Vec256<int16_t>
inline minimum(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b) {
476 return _mm256_min_epi16(a, b);
480 Vec256<int64_t>
inline maximum(
const Vec256<int64_t>& a,
const Vec256<int64_t>& b) {
481 return emulate(a, b, [](int64_t a_point, int64_t b_point) {
return std::max(a_point, b_point);});
485 Vec256<int32_t>
inline maximum(
const Vec256<int32_t>& a,
const Vec256<int32_t>& b) {
486 return _mm256_max_epi32(a, b);
490 Vec256<int16_t>
inline maximum(
const Vec256<int16_t>& a,
const Vec256<int16_t>& b) {
491 return _mm256_max_epi16(a, b);
494 template <
typename T>
495 Vec256<T>
inline intdiv_256(
const Vec256<T>& a,
const Vec256<T>& b) {
496 T values_a[Vec256<T>::size()];
497 T values_b[Vec256<T>::size()];
500 for (
int i = 0; i != Vec256<T>::size(); i++) {
501 values_a[i] /= values_b[i];
503 return Vec256<T>::loadu(values_a);
// Stamps out operator/ (scalar fallback via intdiv_256) and the bitwise
// operators &, |, ^ for all three integer vector widths from one macro.
// NOTE(review): the macro body statements (presumably `return func(a, b);`
// for each width) and their backslash continuations appear collapsed onto
// one line in this view; the text fused after the #undef below looks like
// unrelated extraction residue — confirm against the full file.
506 #define DEFINE_INTEGER_BINARY_OP(op, func) \ 508 Vec256<int64_t> inline operator op(const Vec256<int64_t>& a, const Vec256<int64_t>& b) { \ 512 Vec256<int32_t> inline operator op(const Vec256<int32_t>& a, const Vec256<int32_t>& b) { \ 516 Vec256<int16_t> inline operator op(const Vec256<int16_t>& a, const Vec256<int16_t>& b) { \ 520 DEFINE_INTEGER_BINARY_OP(/, intdiv_256)
521 DEFINE_INTEGER_BINARY_OP(&, _mm256_and_si256)
522 DEFINE_INTEGER_BINARY_OP(|, _mm256_or_si256)
523 DEFINE_INTEGER_BINARY_OP(^, _mm256_xor_si256)
525 #undef DEFINE_INTEGER_BINARY_OP C10_HOST_DEVICE Half operator+(const Half &a, const Half &b)
Arithmetic.
Flush-To-Zero and Denormals-Are-Zero mode.