Caffe2 - C++ API
A deep learning, cross platform ML framework
vec256_double.h
1 #pragma once
2 
3 #include <ATen/cpu/vec256/intrinsics.h>
4 #include <ATen/cpu/vec256/vec256_base.h>
5 #if defined(__AVX__) && !defined(_MSC_VER)
6 #include <sleef.h>
7 #endif
8 
9 namespace at {
10 namespace vec256 {
11 // See Note [Acceptable use of anonymous namespace in header]
12 namespace {
13 
14 #if defined(__AVX__) && !defined(_MSC_VER)
15 
16 template <> class Vec256<double> {
17 private:
18  __m256d values;
19 public:
20  static constexpr int size() {
21  return 4;
22  }
23  Vec256() {}
24  Vec256(__m256d v) : values(v) {}
25  Vec256(double val) {
26  values = _mm256_set1_pd(val);
27  }
28  Vec256(double val1, double val2, double val3, double val4) {
29  values = _mm256_setr_pd(val1, val2, val3, val4);
30  }
31  operator __m256d() const {
32  return values;
33  }
34  template <int64_t mask>
35  static Vec256<double> blend(const Vec256<double>& a, const Vec256<double>& b) {
36  return _mm256_blend_pd(a.values, b.values, mask);
37  }
38  static Vec256<double> blendv(const Vec256<double>& a, const Vec256<double>& b,
39  const Vec256<double>& mask) {
40  return _mm256_blendv_pd(a.values, b.values, mask.values);
41  }
42  static Vec256<double> arange(double base = 0., double step = 1.) {
43  return Vec256<double>(base, base + step, base + 2 * step, base + 3 * step);
44  }
45  static Vec256<double> set(const Vec256<double>& a, const Vec256<double>& b,
46  int64_t count = size()) {
47  switch (count) {
48  case 0:
49  return a;
50  case 1:
51  return blend<1>(a, b);
52  case 2:
53  return blend<3>(a, b);
54  case 3:
55  return blend<7>(a, b);
56  }
57  return b;
58  }
59  static Vec256<double> loadu(const void* ptr, int64_t count = size()) {
60  if (count == size())
61  return _mm256_loadu_pd(reinterpret_cast<const double*>(ptr));
62 
63  __at_align32__ double tmp_values[size()];
64  std::memcpy(
65  tmp_values,
66  reinterpret_cast<const double*>(ptr),
67  count * sizeof(double));
68  return _mm256_load_pd(tmp_values);
69  }
70  void store(void* ptr, int count = size()) const {
71  if (count == size()) {
72  _mm256_storeu_pd(reinterpret_cast<double*>(ptr), values);
73  } else if (count > 0) {
74  double tmp_values[size()];
75  _mm256_storeu_pd(reinterpret_cast<double*>(tmp_values), values);
76  std::memcpy(ptr, tmp_values, count * sizeof(double));
77  }
78  }
79  const double& operator[](int idx) const = delete;
80  double& operator[](int idx) = delete;
81  Vec256<double> map(double (*f)(double)) const {
82  __at_align32__ double tmp[4];
83  store(tmp);
84  for (int64_t i = 0; i < 4; i++) {
85  tmp[i] = f(tmp[i]);
86  }
87  return loadu(tmp);
88  }
89  Vec256<double> abs() const {
90  auto mask = _mm256_set1_pd(-0.f);
91  return _mm256_andnot_pd(mask, values);
92  }
93  Vec256<double> acos() const {
94  return Vec256<double>(Sleef_acosd4_u10(values));
95  }
96  Vec256<double> asin() const {
97  return Vec256<double>(Sleef_asind4_u10(values));
98  }
99  Vec256<double> atan() const {
100  return Vec256<double>(Sleef_atand4_u10(values));
101  }
102  Vec256<double> erf() const {
103  return Vec256<double>(Sleef_erfd4_u10(values));
104  }
105  Vec256<double> erfc() const {
106  return Vec256<double>(Sleef_erfcd4_u15(values));
107  }
108  Vec256<double> exp() const {
109  return Vec256<double>(Sleef_expd4_u10(values));
110  }
111  Vec256<double> expm1() const {
112  return Vec256<double>(Sleef_expm1d4_u10(values));
113  }
114  Vec256<double> log() const {
115  return Vec256<double>(Sleef_logd4_u10(values));
116  }
117  Vec256<double> log2() const {
118  return Vec256<double>(Sleef_log2d4_u10(values));
119  }
120  Vec256<double> log10() const {
121  return Vec256<double>(Sleef_log10d4_u10(values));
122  }
123  Vec256<double> log1p() const {
124  return Vec256<double>(Sleef_log1pd4_u10(values));
125  }
126  Vec256<double> sin() const {
127  return map(std::sin);
128  }
129  Vec256<double> sinh() const {
130  return map(std::sinh);
131  }
132  Vec256<double> cos() const {
133  return map(std::cos);
134  }
135  Vec256<double> cosh() const {
136  return map(std::cos);
137  }
138  Vec256<double> ceil() const {
139  return _mm256_ceil_pd(values);
140  }
141  Vec256<double> floor() const {
142  return _mm256_floor_pd(values);
143  }
144  Vec256<double> neg() const {
145  return _mm256_xor_pd(_mm256_set1_pd(-0.), values);
146  }
147  Vec256<double> round() const {
148  return _mm256_round_pd(values, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
149  }
150  Vec256<double> tan() const {
151  return map(std::tan);
152  }
153  Vec256<double> tanh() const {
154  return Vec256<double>(Sleef_tanhd4_u10(values));
155  }
156  Vec256<double> trunc() const {
157  return _mm256_round_pd(values, (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC));
158  }
159  Vec256<double> sqrt() const {
160  return _mm256_sqrt_pd(values);
161  }
162  Vec256<double> reciprocal() const {
163  return _mm256_div_pd(_mm256_set1_pd(1), values);
164  }
165  Vec256<double> rsqrt() const {
166  return _mm256_div_pd(_mm256_set1_pd(1), _mm256_sqrt_pd(values));
167  }
168  Vec256<double> pow(const Vec256<double> &b) const {
169  return Vec256<double>(Sleef_powd4_u10(values, b));
170  }
171  // Comparison using the _CMP_**_OQ predicate.
172  // `O`: get false if an operand is NaN
173  // `Q`: do not raise if an operand is NaN
174  Vec256<double> operator==(const Vec256<double>& other) const {
175  return _mm256_cmp_pd(values, other.values, _CMP_EQ_OQ);
176  }
177 
178  Vec256<double> operator!=(const Vec256<double>& other) const {
179  return _mm256_cmp_pd(values, other.values, _CMP_NEQ_OQ);
180  }
181 
182  Vec256<double> operator<(const Vec256<double>& other) const {
183  return _mm256_cmp_pd(values, other.values, _CMP_LT_OQ);
184  }
185 
186  Vec256<double> operator<=(const Vec256<double>& other) const {
187  return _mm256_cmp_pd(values, other.values, _CMP_LE_OQ);
188  }
189 
190  Vec256<double> operator>(const Vec256<double>& other) const {
191  return _mm256_cmp_pd(values, other.values, _CMP_GT_OQ);
192  }
193 
194  Vec256<double> operator>=(const Vec256<double>& other) const {
195  return _mm256_cmp_pd(values, other.values, _CMP_GE_OQ);
196  }
197 };
198 
199 template <>
200 Vec256<double> inline operator+(const Vec256<double>& a, const Vec256<double>& b) {
201  return _mm256_add_pd(a, b);
202 }
203 
204 template <>
205 Vec256<double> inline operator-(const Vec256<double>& a, const Vec256<double>& b) {
206  return _mm256_sub_pd(a, b);
207 }
208 
209 template <>
210 Vec256<double> inline operator*(const Vec256<double>& a, const Vec256<double>& b) {
211  return _mm256_mul_pd(a, b);
212 }
213 
214 template <>
215 Vec256<double> inline operator/(const Vec256<double>& a, const Vec256<double>& b) {
216  return _mm256_div_pd(a, b);
217 }
218 
219 // Implements the IEEE 754 201X `maximum` operation, which propagates NaN if
220 // either input is a NaN.
221 template <>
222 Vec256<double> inline maximum(const Vec256<double>& a, const Vec256<double>& b) {
223  Vec256<double> max = _mm256_max_pd(a, b);
224  Vec256<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
225  // Exploit the fact that all-ones is a NaN.
226  return _mm256_or_pd(max, isnan);
227 }
228 
229 // Implements the IEEE 754 201X `minimum` operation, which propagates NaN if
230 // either input is a NaN.
231 template <>
232 Vec256<double> inline minimum(const Vec256<double>& a, const Vec256<double>& b) {
233  Vec256<double> min = _mm256_min_pd(a, b);
234  Vec256<double> isnan = _mm256_cmp_pd(a, b, _CMP_UNORD_Q);
235  // Exploit the fact that all-ones is a NaN.
236  return _mm256_or_pd(min, isnan);
237 }
238 
239 template <>
240 Vec256<double> inline operator&(const Vec256<double>& a, const Vec256<double>& b) {
241  return _mm256_and_pd(a, b);
242 }
243 
244 template <>
245 Vec256<double> inline operator|(const Vec256<double>& a, const Vec256<double>& b) {
246  return _mm256_or_pd(a, b);
247 }
248 
249 template <>
250 Vec256<double> inline operator^(const Vec256<double>& a, const Vec256<double>& b) {
251  return _mm256_xor_pd(a, b);
252 }
253 
254 template <>
255 void convert(const double* src, double* dst, int64_t n) {
256  int64_t i;
257 #pragma unroll
258  for (i = 0; i <= (n - Vec256<double>::size()); i += Vec256<double>::size()) {
259  _mm256_storeu_pd(dst + i, _mm256_loadu_pd(src + i));
260  }
261 #pragma unroll
262  for (; i < n; i++) {
263  dst[i] = src[i];
264  }
265 }
266 
#ifdef __AVX2__
// Fused multiply-add: computes a * b + c with a single rounding step.
template <>
inline Vec256<double> fmadd(const Vec256<double>& a, const Vec256<double>& b, const Vec256<double>& c) {
  return _mm256_fmadd_pd(a, b, c);
}
#endif
273 
274 #endif
275 
276 }}}
C10_HOST_DEVICE Half operator+(const Half &a, const Half &b)
Arithmetic.
Definition: Half-inl.h:56
Flush-To-Zero and Denormals-Are-Zero mode.