#include <ATen/cpu/vec256/intrinsics.h>
#include <ATen/cpu/vec256/vec256_base.h>
#if defined(__AVX__) && !defined(_MSC_VER)
#include <sleef.h>
#endif

#if defined(__AVX__) && !defined(_MSC_VER)

template <>
class Vec256<double> {
private:
  __m256d values;
public:
  static constexpr int size() {
    return 4;
  }
  Vec256() {}
  Vec256(__m256d v) : values(v) {}
  Vec256(double val) {
    values = _mm256_set1_pd(val);
  }
  Vec256(double val1, double val2, double val3, double val4) {
    values = _mm256_setr_pd(val1, val2, val3, val4);
  }
  operator __m256d() const {
    return values;
  }
  template <int64_t mask>
  static Vec256<double> blend(const Vec256<double>& a, const Vec256<double>& b) {
    return _mm256_blend_pd(a.values, b.values, mask);
  }
  static Vec256<double> blendv(const Vec256<double>& a, const Vec256<double>& b,
                               const Vec256<double>& mask) {
    return _mm256_blendv_pd(a.values, b.values, mask.values);
  }
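  // Illustrative usage (not from the original source): blend<mask> picks lane
  // i from b when bit i of the compile-time mask is set, otherwise from a;
  // blendv makes the same per-lane choice from a runtime mask whose "true"
  // lanes are all-ones bit patterns (e.g. the result of a comparison):
  //   Vec256<double> a(1.0), b(2.0);
  //   auto c = Vec256<double>::blend<0b0101>(a, b);  // {2, 1, 2, 1}
  //   auto d = Vec256<double>::blendv(a, b, a < b);  // {2, 2, 2, 2}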
  static Vec256<double> arange(double base = 0., double step = 1.) {
    return Vec256<double>(base, base + step, base + 2 * step, base + 3 * step);
  }
  static Vec256<double> set(const Vec256<double>& a, const Vec256<double>& b,
                            int64_t count = size()) {
    switch (count) {
      case 0:
        return a;
      case 1:
        return blend<1>(a, b);
      case 2:
        return blend<3>(a, b);
      case 3:
        return blend<7>(a, b);
    }
    return b;
  }
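  // Illustrative usage (not from the original source): arange fills the four
  // lanes with an arithmetic progression, and set(a, b, count) takes the
  // first `count` lanes from b and the remaining lanes from a:
  //   auto idx = Vec256<double>::arange(10., 1.);        // {10, 11, 12, 13}
  //   auto mix = Vec256<double>::set(Vec256<double>(0.),
  //                                  idx, /*count=*/2);  // {10, 11, 0, 0}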
  static Vec256<double> loadu(const void* ptr, int64_t count = size()) {
    if (count == size())
      return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));

    __at_align32__ double tmp_values[size()];
    std::memcpy(
        tmp_values,
        reinterpret_cast<const double*>(ptr),
        count * sizeof(double));
    return _mm256_load_pd(tmp_values);
  }
  void store(void* ptr, int count = size()) const {
    if (count == size()) {
      _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
    } else if (count > 0) {
      double tmp_values[size()];
      _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
      std::memcpy(ptr, tmp_values, count * sizeof(double));
    }
  }
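  // Illustrative usage (not from the original source): the `count` arguments
  // let loadu/store touch only a prefix of a vector, which is how callers
  // handle array tails shorter than size() without reading or writing past
  // the end of the buffer:
  //   double src[3] = {1., 2., 3.}, dst[3] = {};
  //   auto v = Vec256<double>::loadu(src, 3);  // fourth lane is unspecified
  //   v.store(dst, 3);                         // writes exactly 3 doubles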
  const double& operator[](int idx) const = delete;
  double& operator[](int idx) = delete;
  Vec256<double> map(double (*f)(double)) const {
    __at_align32__ double tmp[4];
    store(tmp);
    for (int64_t i = 0; i < 4; i++) {
      tmp[i] = f(tmp[i]);
    }
    return loadu(tmp);
  }
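  // Illustrative usage (not from the original source): map is the scalar
  // fallback used when no vectorized kernel is available; it spills the lanes
  // to a temporary array, applies `f` to each element, and reloads:
  //   Vec256<double> x(8.0);
  //   auto y = x.map(std::cbrt);  // applies std::cbrt lane by lane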
  Vec256<double> abs() const {
    auto mask = _mm256_set1_pd(-0.);
    return _mm256_andnot_pd(mask, values);
  }
  Vec256<double> acos() const {
    return Vec256<double>(Sleef_acosd4_u10(values));
  }
  Vec256<double> asin() const {
    return Vec256<double>(Sleef_asind4_u10(values));
  }
  Vec256<double> atan() const {
    return Vec256<double>(Sleef_atand4_u10(values));
  }
  Vec256<double> erf() const {
    return Vec256<double>(Sleef_erfd4_u10(values));
  }
  Vec256<double> erfc() const {
    return Vec256<double>(Sleef_erfcd4_u15(values));
  }
  Vec256<double> exp() const {
    return Vec256<double>(Sleef_expd4_u10(values));
  }
  Vec256<double> expm1() const {
    return Vec256<double>(Sleef_expm1d4_u10(values));
  }
  Vec256<double> log() const {
    return Vec256<double>(Sleef_logd4_u10(values));
  }
  Vec256<double> log2() const {
    return Vec256<double>(Sleef_log2d4_u10(values));
  }
  Vec256<double> log10() const {
    return Vec256<double>(Sleef_log10d4_u10(values));
  }
  Vec256<double> log1p() const {
    return Vec256<double>(Sleef_log1pd4_u10(values));
  }
  Vec256<double> sin() const {
    return map(std::sin);
  }
  Vec256<double> sinh() const {
    return map(std::sinh);
  }
  Vec256<double> cos() const {
    return map(std::cos);
  }
  Vec256<double> cosh() const {
    return map(std::cosh);
  }
  Vec256<double> ceil() const {
    return _mm256_ceil_pd(values);
  }
  Vec256<double> floor() const {
    return _mm256_floor_pd(values);
  }
  Vec256<double> neg() const {
    return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
  }
  Vec256<double> round() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
  }
  Vec256<double> tan() const {
    return map(std::tan);
  }
  Vec256<double> tanh() const {
    return Vec256<double>(Sleef_tanhd4_u10(values));
  }
  Vec256<double> trunc() const {
    return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
  }
  Vec256<double> sqrt() const {
    return _mm256_sqrt_pd(values);
  }
  Vec256<double> reciprocal() const {
    return _mm256_div_pd(_mm256_set1_pd(1), values);
  }
  Vec256<double> rsqrt() const {
    return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
  }
  Vec256<double> pow(const Vec256<double>& b) const {
    return Vec256<double>(Sleef_powd4_u10(values, b));
  }
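  // Illustrative usage (not from the original source): the unary math members
  // mix three strategies -- Sleef vector calls (acos, exp, log, ...), single
  // AVX intrinsics (ceil, floor, round, sqrt, ...), and the scalar map()
  // fallback (sin, cos, tan, ...) -- but all share the same value-returning
  // interface, so they compose directly:
  //   Vec256<double> x(0.5);
  //   auto y = x.exp().log1p().rsqrt();  // 1 / sqrt(log(1 + e^x)), per lane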
  Vec256<double> operator==(const Vec256<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
  }
  Vec256<double> operator!=(const Vec256<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
  }
  Vec256<double> operator<(const Vec256<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ);
  }
  Vec256<double> operator<=(const Vec256<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ);
  }
  Vec256<double> operator>(const Vec256<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ);
  }
  Vec256<double> operator>=(const Vec256<double>& other) const {
    return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ);
  }
};
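// Illustrative note (not from the original source): the comparison operators
// above return a lane-wise mask rather than a bool -- each lane is all-ones
// bits where the predicate holds and all-zero bits otherwise -- which is the
// form expected by blendv and by the NaN trick in maximum/minimum below:
//   Vec256<double> hi(1.0), x(0.25, 1.5, -3.0, 0.75);
//   auto clamped = Vec256<double>::blendv(x, hi, x > hi);  // {0.25, 1, -3, 0.75}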
template <>
Vec256<double> inline operator+(const Vec256<double>& a, const Vec256<double>& b) {
  return _mm256_add_pd(a, b);
}

template <>
Vec256<double> inline operator-(const Vec256<double>& a, const Vec256<double>& b) {
  return _mm256_sub_pd(a, b);
}

template <>
Vec256<double> inline operator*(const Vec256<double>& a, const Vec256<double>& b) {
  return _mm256_mul_pd(a, b);
}

template <>
Vec256<double> inline operator/(const Vec256<double>& a, const Vec256<double>& b) {
  return _mm256_div_pd(a, b);
}
template <>
Vec256<double> inline maximum(const Vec256<double>& a, const Vec256<double>& b) {
  Vec256<double> max = _mm256_max_pd(a, b);
  Vec256<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that the all-ones unordered-compare result is a NaN.
  return _mm256_or_pd(max, isnan);
}

template <>
Vec256<double> inline minimum(const Vec256<double>& a, const Vec256<double>& b) {
  Vec256<double> min = _mm256_min_pd(a, b);
  Vec256<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
  // Exploit the fact that the all-ones unordered-compare result is a NaN.
  return _mm256_or_pd(min, isnan);
}
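// Illustrative sketch (not part of the original header; the helper name is
// hypothetical). A plain _mm256_max_pd returns its second operand when either
// input is NaN, so maximum() ORs in the unordered-compare mask to force any
// lane with a NaN input to NaN. Assumes <cmath> is reachable through the
// includes above for std::nan.
inline Vec256<double> example_nan_propagating_max() {
  Vec256<double> a(1.0, 2.0, std::nan(""), 4.0);
  Vec256<double> b(4.0, 3.0, 2.0, 1.0);
  return maximum(a, b);  // lanes: {4, 3, NaN, 4}
}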
template <>
Vec256<double> inline operator&(const Vec256<double>& a, const Vec256<double>& b) {
  return _mm256_and_pd(a, b);
}

template <>
Vec256<double> inline operator|(const Vec256<double>& a, const Vec256<double>& b) {
  return _mm256_or_pd(a, b);
}

template <>
Vec256<double> inline operator^(const Vec256<double>& a, const Vec256<double>& b) {
  return _mm256_xor_pd(a, b);
}
template <>
inline void convert(const double* src, double* dst, int64_t n) {
  int64_t i;
  for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
    _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
  }
  for (; i < n; i++) {  // scalar tail for the remaining n % size() elements
    dst[i] = src[i];
  }
}
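// Illustrative usage (not from the original source): convert copies n doubles
// four at a time while a full vector remains and finishes the remainder with
// the scalar tail loop, so n need not be a multiple of Vec256<double>::size():
//   double src[6] = {1, 2, 3, 4, 5, 6}, dst[6];
//   convert(src, dst, 6);  // one vector copy for lanes 0-3, scalar for 4-5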
#ifdef __AVX2__
template <>
Vec256<double> inline fmadd(const Vec256<double>& a, const Vec256<double>& b, const Vec256<double>& c) {
  return _mm256_fmadd_pd(a, b, c);
}
#endif
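#ifdef __AVX2__
// Illustrative sketch (not part of the original header; the helper name is
// hypothetical): fmadd(a, b, c) computes a * b + c with a single fused,
// once-rounded operation per lane, which suits axpy-style updates.
inline void example_axpy4(double a, const double* x, double* y) {
  Vec256<double> va(a);
  Vec256<double> vx = Vec256<double>::loadu(x);
  Vec256<double> vy = Vec256<double>::loadu(y);
  fmadd(va, vx, vy).store(y);  // y[0..3] = a * x[0..3] + y[0..3]
}
#endif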
#endif