// Pulls in the Vec256 intrinsics/base/float/double/int headers, then begins
// the stream-insertion operator for Vec256<T>.
// NOTE(review): only the opening of operator<< is visible here (the scratch
// buffer and the loop header); how the lanes are copied and formatted is not
// shown, so confirm against the full file.
3 #include <ATen/cpu/vec256/intrinsics.h> 5 #include <ATen/cpu/vec256/vec256_base.h> 6 #include <ATen/cpu/vec256/vec256_float.h> 7 #include <ATen/cpu/vec256/vec256_double.h> 8 #include <ATen/cpu/vec256/vec256_int.h> 31 std::ostream& operator<<(std::ostream& stream, const Vec256<T>& vec) {
// Scratch array with one T slot per vector lane.
32 T buf[Vec256<T>::size()];
// Walk the lane indices 0 .. Vec256<T>::size() - 1.
35 for (
int i = 0; i != Vec256<T>::size(); i++) {
// The code from here down to the matching #endif is compiled only when
// __AVX__ is defined and the compiler is not MSVC.
//
// cast<float, double>: hands the double vector to _mm256_castpd_ps, which
// reinterprets the 256 bits as eight floats. Per the Intel documentation this
// is a type-only cast: no value conversion takes place.
46 #if defined(__AVX__) && !defined(_MSC_VER) 51 Vec256<float> cast<float, double>(
const Vec256<double>& src) {
52 return _mm256_castpd_ps(src);
// cast<double, float>: hands the float vector to _mm256_castps_pd, which
// reinterprets the 256 bits as four doubles. Per the Intel documentation this
// is a type-only cast: no value conversion takes place.
56 Vec256<double> cast<double, float>(
const Vec256<float>& src) {
57 return _mm256_castps_pd(src);
// DEFINE_FLOAT_INT_CAST(int_t, float_t, float_ch) stamps out two cast
// specializations for one integer/floating-point type pair:
//   cast<int_t, float_t>  -> _mm256_castp<float_ch>_si256(src)
//   cast<float_t, int_t>  -> _mm256_castsi256_p<float_ch>(src)
// float_ch is pasted into the intrinsic name: 'd' selects the double (pd)
// variant and 's' the float (ps) variant. Both intrinsics are bit-pattern
// reinterpretations, not numeric conversions.
// The macro is expanded below for int64_t, int32_t and int16_t against both
// double ('d') and float ('s').
64 #define DEFINE_FLOAT_INT_CAST(int_t, float_t, float_ch) \ 66 Vec256<int_t> cast<int_t, float_t>(const Vec256<float_t>& src) { \ 67 return _mm256_castp ## float_ch ## _si256(src); \ 70 Vec256<float_t> cast<float_t, int_t>(const Vec256<int_t>& src) { \ 71 return _mm256_castsi256_p ## float_ch (src); \ 74 DEFINE_FLOAT_INT_CAST(int64_t,
double, d)
75 DEFINE_FLOAT_INT_CAST(int32_t,
double, d)
76 DEFINE_FLOAT_INT_CAST(int16_t,
double, d)
77 DEFINE_FLOAT_INT_CAST(int64_t,
float, s)
78 DEFINE_FLOAT_INT_CAST(int32_t,
float, s)
79 DEFINE_FLOAT_INT_CAST(int16_t,
float, s)
// The cast-generating macro is no longer needed past this point.
//
// gather (double): loads four doubles through _mm256_i64gather_pd.
//   scale     - compile-time multiplier applied to every index; the enable_if
//               removes this overload unless scale is 1, 2, 4 or 8.
//   base_addr - base pointer the scaled indices are added to.
//   vindex    - four 64-bit indices, one per output lane.
// Per the Intel documentation the effective address of lane i is
// base_addr + vindex[i] * scale in BYTES, so with the default scale of 1 the
// indices are byte offsets rather than element counts.
81 #undef DEFINE_FLOAT_INT_CAST 85 template<
int64_t scale = 1>
86 c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<double>>
87 inline gather(
const double* base_addr,
const Vec256<int64_t>& vindex) {
88 return _mm256_i64gather_pd(base_addr, vindex, scale);
// gather (float): loads eight floats through _mm256_i32gather_ps.
//   scale     - compile-time multiplier applied to every index; the enable_if
//               removes this overload unless scale is 1, 2, 4 or 8.
//   base_addr - base pointer the scaled indices are added to.
//   vindex    - eight 32-bit indices, one per output lane.
// Per the Intel documentation the effective address of lane i is
// base_addr + vindex[i] * scale in BYTES, so with the default scale of 1 the
// indices are byte offsets rather than element counts.
91 template<
int64_t scale = 1>
92 c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<float>>
93 inline gather(
const float* base_addr,
const Vec256<int32_t>& vindex) {
94 return _mm256_i32gather_ps(base_addr, vindex, scale);
// mask_gather (double): conditional four-lane gather through
// _mm256_mask_i64gather_pd.
//   scale     - compile-time index multiplier; must be 1, 2, 4 or 8 or the
//               enable_if removes this overload.
//   src       - fallback values for lanes that are not loaded.
//   base_addr - base pointer the scaled indices are added to.
//   vindex    - four 64-bit indices, one per lane.
//   mask      - per-lane selector. Per the Intel documentation a lane is
//               loaded from memory when the most significant bit of its mask
//               element is set, and copied from src otherwise.
99 template<
int64_t scale = 1>
100 c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<double>>
101 inline mask_gather(
const Vec256<double>& src,
const double* base_addr,
102 const Vec256<int64_t>& vindex,
const Vec256<double>& mask) {
103 return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale);
// mask_gather (float): conditional eight-lane gather through
// _mm256_mask_i32gather_ps.
//   scale     - compile-time index multiplier; must be 1, 2, 4 or 8 or the
//               enable_if removes this overload.
//   src       - fallback values for lanes that are not loaded.
//   base_addr - base pointer the scaled indices are added to.
//   vindex    - eight 32-bit indices, one per lane.
//   mask      - per-lane selector. Per the Intel documentation a lane is
//               loaded from memory when the most significant bit of its mask
//               element is set, and copied from src otherwise.
106 template<
int64_t scale = 1>
107 c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<float>>
108 inline mask_gather(
const Vec256<float>& src,
const float* base_addr,
109 const Vec256<int32_t>& vindex,
const Vec256<float>& mask) {
110 return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
// convert_to_int_of_same_size<double>: turns four doubles into four 64-bit
// integers without a dedicated conversion instruction.
//
// The integer literal 0x0018000000000000 is passed to _mm256_set1_pd, which
// takes a double, so it is converted by VALUE to 6755399441055744.0
// (2^52 + 2^51); it is not used as a raw bit pattern. Adding that constant
// pushes the integer part of src into the low mantissa bits of the sum;
// subtracting the constant's bit pattern from the sum's bit pattern, both
// viewed as 64-bit integers, leaves the integer value.
// NOTE(review): the floating-point add rounds in the current rounding mode
// (not truncation, unlike the float specialization), and the trick presumably
// only holds for |src| up to about 2^51 -- confirm the callers' value range.
119 inline convert_to_int_of_same_size<double>(
const Vec256<double> &src) {
120 auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000));
121 return _mm256_sub_epi64(
122 _mm256_castpd_si256(x),
123 _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))
// convert_to_int_of_same_size<float>: converts eight floats to eight 32-bit
// integers with _mm256_cvttps_epi32. Per the Intel documentation the extra
// 't' in the intrinsic name means the conversion truncates toward zero.
129 inline convert_to_int_of_same_size<float>(
const Vec256<float> &src) {
130 return _mm256_cvttps_epi32(src);
// interleave2<double>: zips two 4-lane vectors.
//   input:  a = {a0, a1, a2, a3}      b = {b0, b1, b2, b3}
//   output: first  = {a0, b0, a1, b1}
//           second = {a2, b2, a3, b3}
136 std::pair<Vec256<double>, Vec256<double>>
137 inline interleave2<double>(
const Vec256<double>& a,
const Vec256<double>& b) {
// Step 1: regroup the 128-bit halves. In the permute2f128 control byte,
// bits [1:0] pick the low result half and bits [5:4] the high result half
// (0 = a.lo, 1 = a.hi, 2 = b.lo, 3 = b.hi).
//   swap_ctrl_a -> {a.lo, b.lo} = {a0, a1, b0, b1}
//   swap_ctrl_b -> {a.hi, b.hi} = {a2, a3, b2, b3}
145 static constexpr
int swap_ctrl_a = 0 | (2 << 4);
146 static constexpr
int swap_ctrl_b = 1 | (3 << 4);
147 auto a_swapped = _mm256_permute2f128_pd(a, b, swap_ctrl_a);
148 auto b_swapped = _mm256_permute2f128_pd(a, b, swap_ctrl_b);
// Step 2: reorder the four lanes as 0, 2, 1, 3 (two control bits per output
// lane) so that the a and b elements alternate.
153 static constexpr
int group_ctrl = 0 | (2 << 2) | (1 << 4) | (3 << 6);
154 return std::make_pair(_mm256_permute4x64_pd(a_swapped, group_ctrl),
155 _mm256_permute4x64_pd(b_swapped, group_ctrl));
// interleave2<float>: zips two 8-lane vectors.
//   input:  a = {a0 .. a7}      b = {b0 .. b7}
//   output: first  = {a0, b0, a1, b1, a2, b2, a3, b3}
//           second = {a4, b4, a5, b5, a6, b6, a7, b7}
159 std::pair<Vec256<float>, Vec256<float>>
160 inline interleave2<float>(
const Vec256<float>& a,
const Vec256<float>& b) {
// Step 1: regroup the 128-bit halves. In the permute2f128 control byte,
// bits [1:0] pick the low result half and bits [5:4] the high result half
// (0 = a.lo, 1 = a.hi, 2 = b.lo, 3 = b.hi).
//   swap_ctrl_a -> {a.lo, b.lo} = {a0..a3, b0..b3}
//   swap_ctrl_b -> {a.hi, b.hi} = {a4..a7, b4..b7}
169 static constexpr
int swap_ctrl_a = 0 | (2 << 4);
170 static constexpr
int swap_ctrl_b = 1 | (3 << 4);
171 auto a_swapped = _mm256_permute2f128_ps(a, b, swap_ctrl_a);
172 auto b_swapped = _mm256_permute2f128_ps(a, b, swap_ctrl_b);
// Step 2: output lane i takes the source lane named by index i, so the
// sequence 0, 4, 1, 5, 2, 6, 3, 7 alternates the a and b elements.
177 const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
178 return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
179 _mm256_permutevar8x32_ps(b_swapped, group_ctrl));
// deinterleave2<double>: splits two 4-lane vectors into their even-indexed
// and odd-indexed lanes (the inverse of the interleave2 permutation).
//   input:  a = {a0, a1, a2, a3}      b = {b0, b1, b2, b3}
//   output: first  = {a0, a2, b0, b2}
//           second = {a1, a3, b1, b3}
185 std::pair<Vec256<double>, Vec256<double>>
186 inline deinterleave2<double>(
const Vec256<double>& a,
const Vec256<double>& b) {
// Step 1: reorder each input's lanes as 0, 2, 1, 3 so the even lanes land in
// the low 128-bit half and the odd lanes in the high half.
194 static constexpr
int group_ctrl = 0 | (2 << 2) | (1 << 4) | (3 << 6);
195 auto a_grouped = _mm256_permute4x64_pd(a, group_ctrl);
196 auto b_grouped = _mm256_permute4x64_pd(b, group_ctrl);
// Step 2: combine matching halves. In the permute2f128 control byte,
// bits [1:0] pick the low result half and bits [5:4] the high result half
// (0 = first.lo, 1 = first.hi, 2 = second.lo, 3 = second.hi).
//   swap_ctrl_a -> both low halves  (the even lanes)
//   swap_ctrl_b -> both high halves (the odd lanes)
201 static constexpr
int swap_ctrl_a = 0 | (2 << 4);
202 static constexpr
int swap_ctrl_b = 1 | (3 << 4);
203 return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, swap_ctrl_a),
204 _mm256_permute2f128_pd(a_grouped, b_grouped, swap_ctrl_b));
// deinterleave2<float>: splits two 8-lane vectors into their even-indexed
// and odd-indexed lanes (the inverse of the interleave2 permutation).
//   input:  a = {a0 .. a7}      b = {b0 .. b7}
//   output: first  = {a0, a2, a4, a6, b0, b2, b4, b6}
//           second = {a1, a3, a5, a7, b1, b3, b5, b7}
208 std::pair<Vec256<float>, Vec256<float>>
209 inline deinterleave2<float>(
const Vec256<float>& a,
const Vec256<float>& b) {
// Step 1: output lane i takes the source lane named by index i, so the
// sequence 0, 2, 4, 6, 1, 3, 5, 7 puts the even lanes in the low 128-bit
// half and the odd lanes in the high half.
218 const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
219 auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl);
220 auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl);
// Step 2: combine matching halves. In the permute2f128 control byte,
// bits [1:0] pick the low result half and bits [5:4] the high result half
// (0 = first.lo, 1 = first.hi, 2 = second.lo, 3 = second.hi).
//   swap_ctrl_a -> both low halves  (the even lanes)
//   swap_ctrl_b -> both high halves (the odd lanes)
225 static constexpr
int swap_ctrl_a = 0 | (2 << 4);
226 static constexpr
int swap_ctrl_b = 1 | (3 << 4);
227 return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, swap_ctrl_a),
228 _mm256_permute2f128_ps(a_grouped, b_grouped, swap_ctrl_b));
231 #endif // defined(__AVX2__) 233 #endif // defined(__AVX__) && !defined(_MSC_VER)
Flush-To-Zero and Denormals-Are-Zero mode.