Caffe2 - C++ API
A deep learning, cross platform ML framework
vec256_float.h
1 #pragma once
2 
3 #include <ATen/cpu/vec256/intrinsics.h>
4 #include <ATen/cpu/vec256/vec256_base.h>
5 #if defined(__AVX__) && !defined(_MSC_VER)
6 #include <sleef.h>
7 #endif
8 
9 namespace at {
10 namespace vec256 {
11 // See Note [Acceptable use of anonymous namespace in header]
12 namespace {
13 
14 #if defined(__AVX__) && !defined(_MSC_VER)
15 
16 template <> class Vec256<float> {
17 private:
18  __m256 values;
19 public:
20  static constexpr int size() {
21  return 8;
22  }
23  Vec256() {}
24  Vec256(__m256 v) : values(v) {}
25  Vec256(float val) {
26  values = _mm256_set1_ps(val);
27  }
28  Vec256(float val1, float val2, float val3, float val4,
29  float val5, float val6, float val7, float val8) {
30  values = _mm256_setr_ps(val1, val2, val3, val4, val5, val6, val7, val8);
31  }
32  operator __m256() const {
33  return values;
34  }
35  template <int64_t mask>
36  static Vec256<float> blend(const Vec256<float>& a, const Vec256<float>& b) {
37  return _mm256_blend_ps(a.values, b.values, mask);
38  }
39  static Vec256<float> blendv(const Vec256<float>& a, const Vec256<float>& b,
40  const Vec256<float>& mask) {
41  return _mm256_blendv_ps(a.values, b.values, mask.values);
42  }
43  static Vec256<float> arange(float base = 0.f, float step = 1.f) {
44  return Vec256<float>(
45  base, base + step, base + 2 * step, base + 3 * step,
46  base + 4 * step, base + 5 * step, base + 6 * step, base + 7 * step);
47  }
48  static Vec256<float> set(const Vec256<float>& a, const Vec256<float>& b,
49  int64_t count = size()) {
50  switch (count) {
51  case 0:
52  return a;
53  case 1:
54  return blend<1>(a, b);
55  case 2:
56  return blend<3>(a, b);
57  case 3:
58  return blend<7>(a, b);
59  case 4:
60  return blend<15>(a, b);
61  case 5:
62  return blend<31>(a, b);
63  case 6:
64  return blend<63>(a, b);
65  case 7:
66  return blend<127>(a, b);
67  }
68  return b;
69  }
70  static Vec256<float> loadu(const void* ptr, int64_t count = size()) {
71  if (count == size())
72  return _mm256_loadu_ps(reinterpret_cast<const float*>(ptr));
73  __at_align32__ float tmp_values[size()];
74  std::memcpy(
75  tmp_values, reinterpret_cast<const float*>(ptr), count * sizeof(float));
76  return _mm256_loadu_ps(tmp_values);
77  }
78  void store(void* ptr, int64_t count = size()) const {
79  if (count == size()) {
80  _mm256_storeu_ps(reinterpret_cast<float*>(ptr), values);
81  } else if (count > 0) {
82  float tmp_values[size()];
83  _mm256_storeu_ps(reinterpret_cast<float*>(tmp_values), values);
84  std::memcpy(ptr, tmp_values, count * sizeof(float));
85  }
86  }
87  const float& operator[](int idx) const = delete;
88  float& operator[](int idx) = delete;
89  Vec256<float> map(float (*f)(float)) const {
90  __at_align32__ float tmp[8];
91  store(tmp);
92  for (int64_t i = 0; i < 8; i++) {
93  tmp[i] = f(tmp[i]);
94  }
95  return loadu(tmp);
96  }
97  Vec256<float> abs() const {
98  auto mask = _mm256_set1_ps(-0.f);
99  return _mm256_andnot_ps(mask, values);
100  }
101  Vec256<float> acos() const {
102  return Vec256<float>(Sleef_acosf8_u10(values));
103  }
104  Vec256<float> asin() const {
105  return Vec256<float>(Sleef_asinf8_u10(values));
106  }
107  Vec256<float> atan() const {
108  return Vec256<float>(Sleef_atanf8_u10(values));
109  }
110  Vec256<float> erf() const {
111  return Vec256<float>(Sleef_erff8_u10(values));
112  }
113  Vec256<float> erfc() const {
114  return Vec256<float>(Sleef_erfcf8_u15(values));
115  }
116  Vec256<float> exp() const {
117  return Vec256<float>(Sleef_expf8_u10(values));
118  }
119  Vec256<float> expm1() const {
120  return Vec256<float>(Sleef_expm1f8_u10(values));
121  }
122  Vec256<float> log() const {
123  return Vec256<float>(Sleef_logf8_u10(values));
124  }
125  Vec256<float> log2() const {
126  return Vec256<float>(Sleef_log2f8_u10(values));
127  }
128  Vec256<float> log10() const {
129  return Vec256<float>(Sleef_log10f8_u10(values));
130  }
131  Vec256<float> log1p() const {
132  return Vec256<float>(Sleef_log1pf8_u10(values));
133  }
134  Vec256<float> sin() const {
135  return map(std::sin);
136  }
137  Vec256<float> sinh() const {
138  return map(std::sinh);
139  }
140  Vec256<float> cos() const {
141  return map(std::cos);
142  }
143  Vec256<float> cosh() const {
144  return map(std::cosh);
145  }
146  Vec256<float> ceil() const {
147  return _mm256_ceil_ps(values);
148  }
149  Vec256<float> floor() const {
150  return _mm256_floor_ps(values);
151  }
152  Vec256<float> neg() const {
153  return _mm256_xor_ps(_mm256_set1_ps(-0.f), values);
154  }
155  Vec256<float> round() const {
156  return _mm256_round_ps(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
157  }
158  Vec256<float> tan() const {
159  return map(std::tan);
160  }
161  Vec256<float> tanh() const {
162  return Vec256<float>(Sleef_tanhf8_u10(values));
163  }
164  Vec256<float> trunc() const {
165  return _mm256_round_ps(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
166  }
167  Vec256<float> sqrt() const {
168  return _mm256_sqrt_ps(values);
169  }
170  Vec256<float> reciprocal() const {
171  return _mm256_div_ps(_mm256_set1_ps(1), values);
172  }
173  Vec256<float> rsqrt() const {
174  return _mm256_div_ps(_mm256_set1_ps(1), _mm256_sqrt_ps(values));
175  }
176  Vec256<float> pow(const Vec256<float> &b) const {
177  return Vec256<float>(Sleef_powf8_u10(values, b));
178  }
179  // Comparison using the _CMP_**_OQ predicate.
180  // `O`: get false if an operand is NaN
181  // `Q`: do not raise if an operand is NaN
182  Vec256<float> operator==(const Vec256<float>& other) const {
183  return _mm256_cmp_ps(values, other.values, _CMP_EQ_OQ);
184  }
185 
186  Vec256<float> operator!=(const Vec256<float>& other) const {
187  return _mm256_cmp_ps(values, other.values, _CMP_NEQ_OQ);
188  }
189 
190  Vec256<float> operator<(const Vec256<float>& other) const {
191  return _mm256_cmp_ps(values, other.values, _CMP_LT_OQ);
192  }
193 
194  Vec256<float> operator<=(const Vec256<float>& other) const {
195  return _mm256_cmp_ps(values, other.values, _CMP_LE_OQ);
196  }
197 
198  Vec256<float> operator>(const Vec256<float>& other) const {
199  return _mm256_cmp_ps(values, other.values, _CMP_GT_OQ);
200  }
201 
202  Vec256<float> operator>=(const Vec256<float>& other) const {
203  return _mm256_cmp_ps(values, other.values, _CMP_GE_OQ);
204  }
205 };
206 
207 template <>
208 Vec256<float> inline operator+(const Vec256<float>& a, const Vec256<float>& b) {
209  return _mm256_add_ps(a, b);
210 }
211 
212 template <>
213 Vec256<float> inline operator-(const Vec256<float>& a, const Vec256<float>& b) {
214  return _mm256_sub_ps(a, b);
215 }
216 
217 template <>
218 Vec256<float> inline operator*(const Vec256<float>& a, const Vec256<float>& b) {
219  return _mm256_mul_ps(a, b);
220 }
221 
222 template <>
223 Vec256<float> inline operator/(const Vec256<float>& a, const Vec256<float>& b) {
224  return _mm256_div_ps(a, b);
225 }
226 
227 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
228 // either input is a NaN.
229 template <>
230 Vec256<float> inline maximum(const Vec256<float>& a, const Vec256<float>& b) {
231  Vec256<float> max = _mm256_max_ps(a, b);
232  Vec256<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
233  // Exploit the fact that all-ones is a NaN.
234  return _mm256_or_ps(max, isnan);
235 }
236 
237 // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
238 // either input is a NaN.
239 template <>
240 Vec256<float> inline minimum(const Vec256<float>& a, const Vec256<float>& b) {
241  Vec256<float> min = _mm256_min_ps(a, b);
242  Vec256<float> isnan = _mm256_cmp_ps(a, b, _CMP_UNORD_Q);
243  // Exploit the fact that all-ones is a NaN.
244  return _mm256_or_ps(min, isnan);
245 }
246 
247 template <>
248 Vec256<float> inline operator&(const Vec256<float>& a, const Vec256<float>& b) {
249  return _mm256_and_ps(a, b);
250 }
251 
252 template <>
253 Vec256<float> inline operator|(const Vec256<float>& a, const Vec256<float>& b) {
254  return _mm256_or_ps(a, b);
255 }
256 
257 template <>
258 Vec256<float> inline operator^(const Vec256<float>& a, const Vec256<float>& b) {
259  return _mm256_xor_ps(a, b);
260 }
261 
262 template <>
263 void convert(const float* src, float* dst, int64_t n) {
264  int64_t i;
265 #pragma unroll
266  for (i = 0; i <= (n - Vec256<float>::size()); i += Vec256<float>::size()) {
267  _mm256_storeu_ps(dst + i, _mm256_loadu_ps(src + i));
268  }
269 #pragma unroll
270  for (; i < n; i++) {
271  dst[i] = src[i];
272  }
273 }
274 
275 #ifdef __AVX2__
276 template <>
277 Vec256<float> inline fmadd(const Vec256<float>& a, const Vec256<float>& b, const Vec256<float>& c) {
278  return _mm256_fmadd_ps(a, b, c);
279 }
280 #endif
281 
282 #endif
283 
284 }}}
C10_HOST_DEVICE Half operator+(const Half &a, const Half &b)
Arithmetic.
Definition: Half-inl.h:56
Flush-To-Zero and Denormals-Are-Zero mode.