3 #include <ATen/cpu/vec256/intrinsics.h> 4 #include <ATen/cpu/vec256/vec256_base.h> 5 #if defined(__AVX__) && !defined(_MSC_VER) 14 #if defined(__AVX__) && !defined(_MSC_VER) 16 template <>
class Vec256<float> {
// Number of float lanes in one vector.
// NOTE(review): the body line is absent from this listing; 8 is restored to
// match the eight-value constructor and the 8-element scratch buffers used
// elsewhere in this class -- confirm against the original file.
static constexpr int size() {
  return 8;
}
24 Vec256(__m256 v) : values(v) {}
26 values = _mm256_set1_ps(val);
28 Vec256(
float val1,
float val2,
float val3,
float val4,
29 float val5,
float val6,
float val7,
float val8) {
30 values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8);
32 operator __m256()
const {
35 template <
int64_t mask>
36 static Vec256<float> blend(
const Vec256<float>& a,
const Vec256<float>& b) {
37 return _mm256_blend_ps(a.values, b.values, mask);
39 static Vec256<float> blendv(
const Vec256<float>& a,
const Vec256<float>& b,
40 const Vec256<float>& mask) {
41 return _mm256_blendv_ps(a.values, b.values, mask.values);
43 static Vec256<float> arange(
float base = 0.f,
float step = 1.f) {
45 base, base + step, base + 2 * step, base + 3 * step,
46 base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
48 static Vec256<float>
set(
const Vec256<float>& a,
const Vec256<float>& b,
49 int64_t count = size()) {
54 return blend<1>(a, b);
56 return blend<3>(a, b);
58 return blend<7>(a, b);
60 return blend<15>(a, b);
62 return blend<31>(a, b);
64 return blend<63>(a, b);
66 return blend<127>(a, b);
70 static Vec256<float> loadu(
const void* ptr, int64_t count = size()) {
72 return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
73 __at_align32__
float tmp_values[size()];
75 tmp_values, reinterpret_cast<const float*>(ptr), count *
sizeof(
float));
76 return _mm256_loadu_ps(tmp_values);
78 void store(
void* ptr, int64_t count = size())
const {
79 if (count == size()) {
80 _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
81 }
else if (count > 0) {
82 float tmp_values[size()];
83 _mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
84 std::memcpy(ptr, tmp_values, count *
sizeof(
float));
87 const float& operator[](
int idx)
const =
delete;
88 float& operator[](
int idx) =
delete;
89 Vec256<float> map(
float (*f)(
float))
const {
90 __at_align32__
float tmp[8];
92 for (int64_t i = 0; i < 8; i++) {
97 Vec256<float> abs()
const {
98 auto mask = _mm256_set1_ps(-0.f);
99 return _mm256_andnot_ps(mask, values);
101 Vec256<float> acos()
const {
102 return Vec256<float>(Sleef_acosf8_u10(values));
104 Vec256<float> asin()
const {
105 return Vec256<float>(Sleef_asinf8_u10(values));
107 Vec256<float> atan()
const {
108 return Vec256<float>(Sleef_atanf8_u10(values));
110 Vec256<float> erf()
const {
111 return Vec256<float>(Sleef_erff8_u10(values));
113 Vec256<float> erfc()
const {
114 return Vec256<float>(Sleef_erfcf8_u15(values));
116 Vec256<float> exp()
const {
117 return Vec256<float>(Sleef_expf8_u10(values));
119 Vec256<float> expm1()
const {
120 return Vec256<float>(Sleef_expm1f8_u10(values));
122 Vec256<float> log()
const {
123 return Vec256<float>(Sleef_logf8_u10(values));
125 Vec256<float> log2()
const {
126 return Vec256<float>(Sleef_log2f8_u10(values));
128 Vec256<float> log10()
const {
129 return Vec256<float>(Sleef_log10f8_u10(values));
131 Vec256<float> log1p()
const {
132 return Vec256<float>(Sleef_log1pf8_u10(values));
134 Vec256<float> sin()
const {
135 return map(std::sin);
137 Vec256<float> sinh()
const {
138 return map(std::sinh);
140 Vec256<float> cos()
const {
141 return map(std::cos);
143 Vec256<float> cosh()
const {
144 return map(std::cosh);
146 Vec256<float> ceil()
const {
147 return _mm256_ceil_ps(values);
149 Vec256<float> floor()
const {
150 return _mm256_floor_ps(values);
152 Vec256<float> neg()
const {
153 return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
155 Vec256<float> round()
const {
156 return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
158 Vec256<float> tan()
const {
159 return map(std::tan);
161 Vec256<float> tanh()
const {
162 return Vec256<float>(Sleef_tanhf8_u10(values));
164 Vec256<float> trunc()
const {
165 return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
167 Vec256<float> sqrt()
const {
168 return _mm256_sqrt_ps(values);
170 Vec256<float> reciprocal()
const {
171 return _mm256_div_ps(_mm256_set1_ps(1), values);
173 Vec256<float> rsqrt()
const {
174 return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
176 Vec256<float> pow(
const Vec256<float> &b)
const {
177 return Vec256<float>(Sleef_powf8_u10(values, b));
182 Vec256<float> operator==(
const Vec256<float>& other)
const {
183 return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
186 Vec256<float> operator!=(
const Vec256<float>& other)
const {
187 return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
190 Vec256<float> operator<(const Vec256<float>& other)
const {
191 return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ);
194 Vec256<float> operator<=(const Vec256<float>& other)
const {
195 return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ);
198 Vec256<float> operator>(
const Vec256<float>& other)
const {
199 return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ);
202 Vec256<float> operator>=(
const Vec256<float>& other)
const {
203 return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ);
208 Vec256<float>
inline operator+(
const Vec256<float>& a,
const Vec256<float>& b) {
209 return _mm256_add_ps(a, b);
213 Vec256<float>
inline operator-(
const Vec256<float>& a,
const Vec256<float>& b) {
214 return _mm256_sub_ps(a, b);
218 Vec256<float>
inline operator*(
const Vec256<float>& a,
const Vec256<float>& b) {
219 return _mm256_mul_ps(a, b);
223 Vec256<float>
inline operator/(
const Vec256<float>& a,
const Vec256<float>& b) {
224 return _mm256_div_ps(a, b);
230 Vec256<float>
inline maximum(
const Vec256<float>& a,
const Vec256<float>& b) {
231 Vec256<float> max = _mm256_max_ps(a, b);
232 Vec256<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
234 return _mm256_or_ps(max, isnan);
240 Vec256<float>
inline minimum(
const Vec256<float>& a,
const Vec256<float>& b) {
241 Vec256<float> min = _mm256_min_ps(a, b);
242 Vec256<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
244 return _mm256_or_ps(min, isnan);
248 Vec256<float>
inline operator&(
const Vec256<float>& a,
const Vec256<float>& b) {
249 return _mm256_and_ps(a, b);
253 Vec256<float>
inline operator|(
const Vec256<float>& a,
const Vec256<float>& b) {
254 return _mm256_or_ps(a, b);
258 Vec256<float>
inline operator^(
const Vec256<float>& a,
const Vec256<float>& b) {
259 return _mm256_xor_ps(a, b);
263 void convert(
const float* src,
float* dst, int64_t n) {
266 for (i = 0; i <= (n - Vec256<float>::size()); i += Vec256<float>::size()) {
267 _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
277 Vec256<float>
inline fmadd(
const Vec256<float>& a,
const Vec256<float>& b,
const Vec256<float>& c) {
278 return _mm256_fmadd_ps(a, b, c);
C10_HOST_DEVICE Half operator+(const Half &a, const Half &b)
Arithmetic.
Flush-To-Zero and Denormals-Are-Zero mode.