#include <cstddef>
#include <cstdint>
#include <cstring>
#include <cmath>
#include <utility>
#include <type_traits>

#include <ATen/Utils.h>
#include <ATen/native/Copy.h>
#include <ATen/NumericUtils.h>
#include <c10/util/C++17.h>

// Pick the 32-byte alignment attribute understood by the current compiler.
#if defined(__GNUC__)
#define __at_align32__ __attribute__((aligned(32)))
#elif defined(_WIN32)
#define __at_align32__ __declspec(align(32))
#else
#define __at_align32__
#endif

template<size_t n> struct int_of_size;
#define DEFINE_INT_OF_SIZE(int_t) \
template<> struct int_of_size<sizeof(int_t)> { using type = int_t; }

DEFINE_INT_OF_SIZE(int64_t);
DEFINE_INT_OF_SIZE(int32_t);
DEFINE_INT_OF_SIZE(int16_t);
DEFINE_INT_OF_SIZE(int8_t);

#undef DEFINE_INT_OF_SIZE

template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
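// Illustration (not in the original header): int_same_size_t<T> is the signed
// integer type with the same byte width as T.
static_assert(std::is_same<int_same_size_t<float>, int32_t>::value,
              "float pairs with int32_t");
static_assert(std::is_same<int_same_size_t<double>, int64_t>::value,
              "double pairs with int64_t");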
// Emulates vectorized types with a plain scalar loop; a Vec256<T> always
// holds 32 bytes, i.e. 32 / sizeof(T) lanes of T.
template <class T>
struct Vec256 {
private:
  T values[32 / sizeof(T)] = {0};
public:
  static constexpr int size() {
    return 32 / sizeof(T);
  }
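  // For example, Vec256<float>::size() == 8 and Vec256<double>::size() == 4.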
  Vec256() {}
  // Broadcast one scalar into every lane.
  Vec256(T val) {
    for (int i = 0; i != size(); i++) {
      values[i] = val;
    }
  }
  // Lane-wise construction; enabled only when the argument count matches the
  // lane count.
  template<typename... Args,
           typename = c10::guts::enable_if_t<(sizeof...(Args) == size())>>
  Vec256(Args... vals) : values{vals...} {}
  // Per-lane select driven by a compile-time bit mask: lane i comes from b if
  // bit i of mask_ is set, from a otherwise.
  template <int64_t mask_>
  static Vec256<T> blend(const Vec256<T>& a, const Vec256<T>& b) {
    int64_t mask = mask_;
    Vec256 vec;
    for (int64_t i = 0; i < size(); i++) {
      vec[i] = (mask & 0x01) ? b[i] : a[i];
      mask = mask >> 1;
    }
    return vec;
  }
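  // Usage sketch (hypothetical values): blend<0b11>(a, b) takes the two
  // lowest lanes from b and all remaining lanes from a.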
  // Runtime variant of blend: the mask is another vector, with each lane
  // expected to be all-ones (as produced by the comparison operators) or zero.
  static Vec256<T> blendv(const Vec256<T>& a, const Vec256<T>& b,
                          const Vec256<T>& mask) {
    Vec256 vec;
    int_same_size_t<T> buffer[size()];
    mask.store(buffer);
    for (int64_t i = 0; i < size(); i++) {
      vec[i] = (buffer[i] & 0x01) ? b[i] : a[i];
    }
    return vec;
  }
  // Returns {base, base + step, base + 2 * step, ...}.
  static Vec256<T> arange(T base = static_cast<T>(0), T step = static_cast<T>(1)) {
    Vec256 vec;
    for (int64_t i = 0; i < size(); i++) {
      vec.values[i] = base + i * step;
    }
    return vec;
  }
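  // e.g. Vec256<float>::arange(10.f, 2.f) holds {10, 12, 14, 16, 18, 20, 22, 24}.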
  // Returns b in the first `count` lanes and a in the rest.
  static Vec256<T> set(const Vec256<T>& a, const Vec256<T>& b,
                       int64_t count = size()) {
    Vec256 vec;
    for (int64_t i = 0; i < size(); i++) {
      vec[i] = (i < count) ? b[i] : a[i];
    }
    return vec;
  }
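  // Usage sketch: set(a, b, 3) yields {b[0], b[1], b[2], a[3], a[4], ...}.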
  // Unaligned load of a full 32-byte vector.
  static Vec256<T> loadu(const void* ptr) {
    Vec256 vec;
    std::memcpy(vec.values, ptr, 32);
    return vec;
  }
  // Unaligned load of only the first `count` elements; the rest stay zero.
  static Vec256<T> loadu(const void* ptr, int64_t count) {
    Vec256 vec;
    std::memcpy(vec.values, ptr, count * sizeof(T));
    return vec;
  }
  void store(void* ptr, int count = size()) const {
    std::memcpy(ptr, values, count * sizeof(T));
  }
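  // Usage sketch:
  //   float buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  //   auto v = Vec256<float>::loadu(buf);  // v[i] == buf[i]
  //   v.store(buf, 4);                     // writes back only buf[0..3]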
  const T& operator[](int idx) const {
    return values[idx];
  }
  T& operator[](int idx) {
    return values[idx];
  }
  // Applies a scalar function lane-by-lane.
  Vec256<T> map(T (*f)(T)) const {
    Vec256<T> ret;
    for (int64_t i = 0; i != size(); i++) {
      ret[i] = f(values[i]);
    }
    return ret;
  }
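  // e.g. v.map(std::sin) computes std::sin on every lane; most of the unary
  // math functions below (acos, exp, log, ...) are defined exactly this way.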
  Vec256<T> abs() const {
    Vec256<T> ret;
    for (int64_t i = 0; i < size(); i++) {
      ret[i] = values[i] < 0 ? -values[i] : values[i];
    }
    return ret;
  }
  Vec256<T> acos() const {
    return map(std::acos);
  }
  Vec256<T> asin() const {
    return map(std::asin);
  }
  Vec256<T> atan() const {
    return map(std::atan);
  }
  Vec256<T> erf() const {
    return map(std::erf);
  }
  Vec256<T> erfc() const {
    return map(std::erfc);
  }
  Vec256<T> exp() const {
    return map(std::exp);
  }
  Vec256<T> expm1() const {
    return map(std::expm1);
  }
  Vec256<T> log() const {
    return map(std::log);
  }
  Vec256<T> log10() const {
    return map(std::log10);
  }
  Vec256<T> log1p() const {
    return map(std::log1p);
  }
  Vec256<T> log2() const {
    return map(std::log2);
  }
  Vec256<T> ceil() const {
    return map(std::ceil);
  }
  Vec256<T> cos() const {
    return map(std::cos);
  }
  Vec256<T> cosh() const {
    return map(std::cosh);
  }
  Vec256<T> floor() const {
    return map(std::floor);
  }
  Vec256<T> neg() const {
    return map([](T x) { return -x; });
  }
  Vec256<T> round() const {
    return map(std::nearbyint);
  }
  Vec256<T> sin() const {
    return map(std::sin);
  }
  Vec256<T> sinh() const {
    return map(std::sinh);
  }
  Vec256<T> tan() const {
    return map(std::tan);
  }
  Vec256<T> tanh() const {
    return map(std::tanh);
  }
  Vec256<T> trunc() const {
    return map(std::trunc);
  }
  Vec256<T> sqrt() const {
    return map(std::sqrt);
  }
  Vec256<T> reciprocal() const {
    return map([](T x) { return (T)(1) / x; });
  }
  Vec256<T> rsqrt() const {
    return map([](T x) { return 1 / std::sqrt(x); });
  }
  Vec256<T> pow(const Vec256<T> &exp) const {
    Vec256<T> ret;
    for (int64_t i = 0; i < size(); i++) {
      ret[i] = std::pow(values[i], exp[i]);
    }
    return ret;
  }
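  // Usage sketch:
  //   auto v = Vec256<float>(2.0f);         // broadcast constructor
  //   auto r = v.pow(Vec256<float>(3.0f));  // every lane holds 8.0f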
  // Comparison operators return a lane mask: all bits of a lane are set where
  // the predicate holds, and cleared otherwise.
#define DEFINE_COMP(binary_pred)                                            \
  Vec256<T> operator binary_pred(const Vec256<T> &other) const {            \
    Vec256<T> vec;                                                          \
    for (int64_t i = 0; i != size(); i++) {                                 \
      if (values[i] binary_pred other.values[i]) {                          \
        std::memset(static_cast<void*>(vec.values + i), 0xFF, sizeof(T));   \
      } else {                                                              \
        std::memset(static_cast<void*>(vec.values + i), 0, sizeof(T));      \
      }                                                                     \
    }                                                                       \
    return vec;                                                             \
  }
  DEFINE_COMP(==)
  DEFINE_COMP(!=)
  DEFINE_COMP(>=)
  DEFINE_COMP(<=)
  DEFINE_COMP(>)
  DEFINE_COMP(<)
#undef DEFINE_COMP
};

template <class T> Vec256<T> inline operator+(const Vec256<T> &a, const Vec256<T> &b) {
  Vec256<T> c = Vec256<T>();
  for (int i = 0; i != Vec256<T>::size(); i++) {
    c[i] = a[i] + b[i];
  }
  return c;
}
template <class T> Vec256<T> inline operator-(const Vec256<T> &a, const Vec256<T> &b) {
  Vec256<T> c = Vec256<T>();
  for (int i = 0; i != Vec256<T>::size(); i++) {
    c[i] = a[i] - b[i];
  }
  return c;
}
template <class T> Vec256<T> inline operator*(const Vec256<T> &a, const Vec256<T> &b) {
  Vec256<T> c = Vec256<T>();
  for (int i = 0; i != Vec256<T>::size(); i++) {
    c[i] = a[i] * b[i];
  }
  return c;
}
template <class T> Vec256<T> inline operator/(const Vec256<T> &a, const Vec256<T> &b) __ubsan_ignore_float_divide_by_zero__ {
  Vec256<T> c = Vec256<T>();
  for (int i = 0; i != Vec256<T>::size(); i++) {
    c[i] = a[i] / b[i];
  }
  return c;
}
template <class T> Vec256<T> inline operator||(const Vec256<T> &a, const Vec256<T> &b) {
  Vec256<T> c = Vec256<T>();
  for (int i = 0; i != Vec256<T>::size(); i++) {
    c[i] = a[i] || b[i];
  }
  return c;
}
// Lane-wise maximum.
template <class T> Vec256<T> inline maximum(const Vec256<T> &a, const Vec256<T> &b) {
  Vec256<T> c = Vec256<T>();
  for (int i = 0; i != Vec256<T>::size(); i++) {
    c[i] = (a[i] > b[i]) ? a[i] : b[i];
    if (at::_isnan(a[i])) {
      // Propagate NaN; the ternary above already covers a NaN in b[i].
      c[i] = a[i];
    }
  }
  return c;
}
template <typename T>
inline T maximum(const T& a, const T& b) {
  T c = (a > b) ? a : b;
  if (at::_isnan(a)) {
    c = a;
  }
  return c;
}
// Lane-wise minimum.
template <class T> Vec256<T> inline minimum(const Vec256<T> &a, const Vec256<T> &b) {
  Vec256<T> c = Vec256<T>();
  for (int i = 0; i != Vec256<T>::size(); i++) {
    c[i] = (a[i] < b[i]) ? a[i] : b[i];
    if (at::_isnan(a[i])) {
      // Propagate NaN; the ternary above already covers a NaN in b[i].
      c[i] = a[i];
    }
  }
  return c;
}
template <typename T>
inline T minimum(const T& a, const T& b) {
  T c = (a < b) ? a : b;
  if (at::_isnan(a)) {
    c = a;
  }
  return c;
}
// Bitwise ops reinterpret each lane as the integer type of the same width,
// so they also work for floating-point T.
#define DEFINE_BITWISE_OP(op)                                            \
template <class T>                                                       \
Vec256<T> inline operator op(const Vec256<T> &a, const Vec256<T> &b) {  \
  using iT = int_same_size_t<T>;                                         \
  iT buffer[Vec256<T>::size()];                                          \
  for (int64_t i = 0; i != Vec256<T>::size(); i++) {                     \
    auto a_val = a[i];                                                   \
    auto b_val = b[i];                                                   \
    iT *i_a_ptr = reinterpret_cast<iT*>(&a_val);                         \
    iT *i_b_ptr = reinterpret_cast<iT*>(&b_val);                         \
    buffer[i] = *i_a_ptr op *i_b_ptr;                                    \
  }                                                                      \
  return Vec256<T>::loadu(buffer);                                       \
}
DEFINE_BITWISE_OP(&)
DEFINE_BITWISE_OP(|)
DEFINE_BITWISE_OP(^)
#undef DEFINE_BITWISE_OP

template <typename T>
inline T fmadd(const T& a, const T& b, const T& c) {
  return a * b + c;
}
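// Sketch only: in this scalar fallback fmadd is literally a * b + c; SIMD
// specializations of Vec256 are free to lower it to a hardware fused
// multiply-add instead.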
template <int64_t scale = 1, typename T = void>
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<T>>
inline gather(T const* base_addr, const Vec256<int_same_size_t<T>>& vindex) {
  static constexpr int size = Vec256<T>::size();
  int_same_size_t<T> index_arr[size];
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (int64_t i = 0; i < size; i++) {
    buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
  }
  return Vec256<T>::loadu(static_cast<void*>(buffer));
}
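// Usage sketch: as in the AVX2 gather intrinsics, lane i reads the element at
// byte offset index_arr[i] * scale from base_addr, so scale == sizeof(T)
// gives plain indexed access: buffer[i] = base_addr[index_arr[i]].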
template <int64_t scale = 1, typename T = void>
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<T>>
inline mask_gather(const Vec256<T>& src, T const* base_addr,
                   const Vec256<int_same_size_t<T>>& vindex, Vec256<T>& mask) {
  static constexpr int size = Vec256<T>::size();
  T src_arr[size];
  int_same_size_t<T> mask_arr[size];
  int_same_size_t<T> index_arr[size];
  src.store(static_cast<void*>(src_arr));
  mask.store(static_cast<void*>(mask_arr));
  vindex.store(static_cast<void*>(index_arr));
  T buffer[size];
  for (int64_t i = 0; i < size; i++) {
    if (mask_arr[i] & 0x01) {
      buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
    } else {
      buffer[i] = src_arr[i];
    }
  }
  mask = Vec256<T>();  // "zero out" the mask, as hardware mask-gathers do
  return Vec256<T>::loadu(static_cast<void*>(buffer));
}
// Generic cast: reinterpret the raw bytes of one vector as another type.
template<typename dst_t, typename src_t>
struct CastImpl {
  static inline Vec256<dst_t> apply(const Vec256<src_t>& src) {
    src_t src_arr[Vec256<src_t>::size()];
    src.store(static_cast<void*>(src_arr));
    return Vec256<dst_t>::loadu(static_cast<const void*>(src_arr));
  }
};
// Casting to the same type is the identity.
template<typename scalar_t>
struct CastImpl<scalar_t, scalar_t> {
  static inline Vec256<scalar_t> apply(const Vec256<scalar_t>& src) {
    return src;
  }
};
template<typename dst_t, typename src_t>
Vec256<dst_t> cast(const Vec256<src_t>& src) {
  return CastImpl<dst_t, src_t>::apply(src);
}
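// Note the distinction: cast<dst_t> reinterprets bytes without converting
// values, whereas convert_to_int_of_same_size below converts each lane with
// static_cast.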
template <typename T>
inline Vec256<int_same_size_t<T>> convert_to_int_of_same_size(const Vec256<T>& src) {
  static constexpr int size = Vec256<T>::size();
  T src_arr[size];
  src.store(static_cast<void*>(src_arr));
  int_same_size_t<T> buffer[size];
  for (int64_t i = 0; i < size; i++) {
    buffer[i] = static_cast<int_same_size_t<T>>(src_arr[i]);
  }
  return Vec256<int_same_size_t<T>>::loadu(static_cast<void*>(buffer));
}
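// e.g. convert_to_int_of_same_size(Vec256<float>(2.5f)) yields a
// Vec256<int32_t> with every lane equal to 2 (truncating conversion).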
// Splits two vectors of interleaved pairs {a0, b0, a1, b1, ...} into
// {a0, a1, ...} and {b0, b1, ...}.
template <typename T>
inline c10::guts::enable_if_t<Vec256<T>::size() % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
deinterleave2(const Vec256<T>& a, const Vec256<T>& b) {
  static constexpr int size = Vec256<T>::size();
  static constexpr int half_size = size / 2;
  T a_arr[size];
  T b_arr[size];
  T buffer1[size];
  T buffer2[size];
  a.store(static_cast<void*>(a_arr));
  b.store(static_cast<void*>(b_arr));
  for (int64_t i = 0; i < half_size; i++) {
    buffer1[i] = a_arr[i * 2];
    buffer1[half_size + i] = b_arr[i * 2];
    buffer2[i] = a_arr[i * 2 + 1];
    buffer2[half_size + i] = b_arr[i * 2 + 1];
  }
  return std::make_pair(Vec256<T>::loadu(static_cast<void*>(buffer1)),
                        Vec256<T>::loadu(static_cast<void*>(buffer2)));
}
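// Example for Vec256<double> (4 lanes): a = {x0, y0, x1, y1},
// b = {x2, y2, x3, y3} -> first = {x0, x1, x2, x3}, second = {y0, y1, y2, y3}.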
// Inverse of deinterleave2: merges {a0, a1, ...} and {b0, b1, ...} back into
// interleaved pairs {a0, b0, a1, b1, ...}.
template <typename T>
inline c10::guts::enable_if_t<Vec256<T>::size() % 2 == 0, std::pair<Vec256<T>, Vec256<T>>>
interleave2(const Vec256<T>& a, const Vec256<T>& b) {
  static constexpr int size = Vec256<T>::size();
  static constexpr int half_size = size / 2;
  T a_arr[size];
  T b_arr[size];
  T buffer1[size];
  T buffer2[size];
  a.store(static_cast<void*>(a_arr));
  b.store(static_cast<void*>(b_arr));
  for (int64_t i = 0; i < half_size; i++) {
    buffer1[i * 2] = a_arr[i];
    buffer1[i * 2 + 1] = b_arr[i];
    buffer2[i * 2] = a_arr[half_size + i];
    buffer2[i * 2 + 1] = b_arr[half_size + i];
  }
  return std::make_pair(Vec256<T>::loadu(static_cast<void*>(buffer1)),
                        Vec256<T>::loadu(static_cast<void*>(buffer2)));
}
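// Example for Vec256<double> (4 lanes): a = {x0, x1, x2, x3},
// b = {y0, y1, y2, y3} -> first = {x0, y0, x1, y1}, second = {x2, y2, x3, y3}.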
template <typename src_T, typename dst_T>
void convert(const src_T *src, dst_T *dst, int64_t n) {
  for (int64_t i = 0; i < n; i++) {
    *dst = static_cast<dst_T>(static_cast<at::native::inter_copy_type_t<dst_T>>(*src));
    src++;
    dst++;
  }
}
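// Usage sketch (hypothetical buffers):
//   double in[16]; float out[16];
//   convert(in, out, 16);  // element-wise static_cast through the
//                          // intermediate inter_copy_type_t<float>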