// Caffe2 - C++ API
// A deep learning, cross platform ML framework
// vec256.h
1 #pragma once
2 
3 #include <ATen/cpu/vec256/intrinsics.h>
4 
5 #include <ATen/cpu/vec256/vec256_base.h>
6 #include <ATen/cpu/vec256/vec256_float.h>
7 #include <ATen/cpu/vec256/vec256_double.h>
8 #include <ATen/cpu/vec256/vec256_int.h>
9 
10 #include <algorithm>
11 #include <cstddef>
12 #include <cstdint>
13 #include <cstring>
14 #include <iostream>
15 
16 namespace at {
17 namespace vec256 {
18 
19 // Note [Acceptable use of anonymous namespace in header]
20 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
21 // Yes you saw right, this is an anonymous namespace in a header. This header,
22 // and all of its subheaders, REQUIRE their code to be entirely inlined into
23 // the compilation unit that uses them. It's important that these functions have
24 // internal linkage so that kernels for different architectures don't get
25 // combined during linking. It's sufficient to label functions "static", but
26 // class methods must be an unnamed namespace to have internal linkage (since
27 // static means something different in the context of classes).
28 namespace {
29 
30 template <typename T>
31 std::ostream& operator<<(std::ostream& stream, const Vec256<T>& vec) {
32  T buf[Vec256<T>::size()];
33  vec.store(buf);
34  stream << "vec[";
35  for (int i = 0; i != Vec256<T>::size(); i++) {
36  if (i != 0) {
37  stream << ", ";
38  }
39  stream << buf[i];
40  }
41  stream << "]";
42  return stream;
43 }
44 
45 
46 #if defined(__AVX__) && !defined(_MSC_VER)
47 
48 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
49 
// Bitwise reinterpretation of a 256-bit double vector as a float vector.
// No value conversion is performed (the cast intrinsic generates no code).
template<>
Vec256<float> cast<float, double>(const Vec256<double>& src) {
  return _mm256_castpd_ps(src);
}
54 
// Bitwise reinterpretation of a 256-bit float vector as a double vector.
// No value conversion is performed (the cast intrinsic generates no code).
template<>
Vec256<double> cast<double, float>(const Vec256<float>& src) {
  return _mm256_castps_pd(src);
}
59 
60 #if defined(__AVX2__)
61 
62 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST (AVX2) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
63 
// Generates the pair of bitwise-cast specializations between an integer
// vector type and a floating-point vector type of the same 256-bit width.
// `float_ch` is the intrinsic-name suffix for the float type: `d` (-> pd)
// for double, `s` (-> ps) for float; token pasting builds e.g.
// _mm256_castpd_si256 / _mm256_castsi256_pd. These reinterpret raw bits
// only — no value conversion.
#define DEFINE_FLOAT_INT_CAST(int_t, float_t, float_ch) \
template<> \
Vec256<int_t> cast<int_t, float_t>(const Vec256<float_t>& src) { \
  return _mm256_castp ## float_ch ## _si256(src); \
} \
template<> \
Vec256<float_t> cast<float_t, int_t>(const Vec256<int_t>& src) { \
  return _mm256_castsi256_p ## float_ch (src); \
}

DEFINE_FLOAT_INT_CAST(int64_t, double, d)
DEFINE_FLOAT_INT_CAST(int32_t, double, d)
DEFINE_FLOAT_INT_CAST(int16_t, double, d)
DEFINE_FLOAT_INT_CAST(int64_t, float, s)
DEFINE_FLOAT_INT_CAST(int32_t, float, s)
DEFINE_FLOAT_INT_CAST(int16_t, float, s)

#undef DEFINE_FLOAT_INT_CAST
82 
83 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
84 
// Gathers one double per lane from base_addr + vindex[i] * scale (in bytes).
// `scale` must be 1, 2, 4, or 8 — the gather instruction encodes it as an
// immediate — which the enable_if on the return type enforces at compile time.
template<int64_t scale = 1>
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<double>>
inline gather(const double* base_addr, const Vec256<int64_t>& vindex) {
  return _mm256_i64gather_pd(base_addr, vindex, scale);
}
90 
// Gathers one float per lane from base_addr + vindex[i] * scale (in bytes).
// `scale` must be 1, 2, 4, or 8 (immediate operand of the gather
// instruction), enforced via the enable_if on the return type.
template<int64_t scale = 1>
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<float>>
inline gather(const float* base_addr, const Vec256<int32_t>& vindex) {
  return _mm256_i32gather_ps(base_addr, vindex, scale);
}
96 
97 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
98 
// Masked gather: per the Intel intrinsic's contract, lanes whose mask
// element has its most significant bit set are loaded from
// base_addr + vindex[i] * scale (bytes); other lanes pass through src.
// `scale` must be 1, 2, 4, or 8 (SFINAE-enforced immediate).
template<int64_t scale = 1>
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<double>>
inline mask_gather(const Vec256<double>& src, const double* base_addr,
                   const Vec256<int64_t>& vindex, const Vec256<double>& mask) {
  return _mm256_mask_i64gather_pd(src, base_addr, vindex, mask, scale);
}
105 
// Masked gather (float variant): lanes whose mask element has its most
// significant bit set are loaded from base_addr + vindex[i] * scale (bytes);
// other lanes pass through src. `scale` must be 1, 2, 4, or 8.
template<int64_t scale = 1>
c10::guts::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vec256<float>>
inline mask_gather(const Vec256<float>& src, const float* base_addr,
                   const Vec256<int32_t>& vindex, const Vec256<float>& mask) {
  return _mm256_mask_i32gather_ps(src, base_addr, vindex, mask, scale);
}
112 
113 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
114 
115 // Only works for inputs in the range: [-2^51, 2^51]
116 // From: https://stackoverflow.com/a/41148578
// double -> int64 conversion without a dedicated instruction (AVX2 has no
// packed cvtpd_epi64). Adding the magic constant 0x0018000000000000
// (= 2^52 + 2^51 as a double) pushes the integral value into the low
// mantissa bits; subtracting the constant's raw bit pattern as an int64
// recovers the signed integer.
// NOTE(review): unlike the float path below (_mm256_cvttps_epi32, which
// truncates), this trick converts under the FP rounding mode in effect —
// confirm callers tolerate the difference.
// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
template<>
Vec256<int64_t>
inline convert_to_int_of_same_size<double>(const Vec256<double> &src) {
  auto x = _mm256_add_pd(src, _mm256_set1_pd(0x0018000000000000));
  return _mm256_sub_epi64(
    _mm256_castpd_si256(x),
    _mm256_castpd_si256(_mm256_set1_pd(0x0018000000000000))
  );
}
126 
// float -> int32 conversion; _mm256_cvttps_epi32 truncates toward zero
// (the extra 't' in the mnemonic).
template<>
Vec256<int32_t>
inline convert_to_int_of_same_size<float>(const Vec256<float> &src) {
  return _mm256_cvttps_epi32(src);
}
132 
133 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
134 
// Interleaves two vectors element-wise: given columns a and b, returns
// ({a0, b0, a1, b1}, {a2, b2, a3, b3}).
template <>
std::pair<Vec256<double>, Vec256<double>>
inline interleave2<double>(const Vec256<double>& a, const Vec256<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}

  // swap lanes:
  //   a_swapped = {a0, a1, b0, b1}
  //   b_swapped = {a2, a3, b2, b3}
  // (permute2f128 picks one 128-bit lane per 4-bit control field from the
  // concatenation of a and b)
  static constexpr int swap_ctrl_a = 0 | (2 << 4);  // 0, 2.   4 bits apart
  static constexpr int swap_ctrl_b = 1 | (3 << 4);  // 1, 3.   4 bits apart
  auto a_swapped = _mm256_permute2f128_pd(a, b, swap_ctrl_a);
  auto b_swapped = _mm256_permute2f128_pd(a, b, swap_ctrl_b);

  // group cols crossing lanes:
  //   return {a0, b0, a1, b1}
  //          {a2, b2, a3, b3}
  static constexpr int group_ctrl = 0 | (2 << 2) | (1 << 4) | (3 << 6);  // 0, 2, 1, 3
  return std::make_pair(_mm256_permute4x64_pd(a_swapped, group_ctrl),
                        _mm256_permute4x64_pd(b_swapped, group_ctrl));
}
157 
// Interleaves two float vectors element-wise: given columns a and b, returns
// ({a0, b0, ..., a3, b3}, {a4, b4, ..., a7, b7}).
template <>
std::pair<Vec256<float>, Vec256<float>>
inline interleave2<float>(const Vec256<float>& a, const Vec256<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}

  // swap lanes:
  //   a_swapped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_swapped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  static constexpr int swap_ctrl_a = 0 | (2 << 4);  // 0, 2.   4 bits apart
  static constexpr int swap_ctrl_b = 1 | (3 << 4);  // 1, 3.   4 bits apart
  auto a_swapped = _mm256_permute2f128_ps(a, b, swap_ctrl_a);
  auto b_swapped = _mm256_permute2f128_ps(a, b, swap_ctrl_b);

  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  // (permutevar8x32 takes its shuffle indices in a vector register, so this
  // control cannot be a compile-time immediate like the double path's)
  const __m256i group_ctrl = _mm256_setr_epi32(0, 4, 1, 5, 2, 6, 3, 7);
  return std::make_pair(_mm256_permutevar8x32_ps(a_swapped, group_ctrl),
                        _mm256_permutevar8x32_ps(b_swapped, group_ctrl));
}
181 
182 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
183 
// Inverse of interleave2<double>: splits two interleaved vectors back into
// separate a and b columns.
template <>
std::pair<Vec256<double>, Vec256<double>>
inline deinterleave2<double>(const Vec256<double>& a, const Vec256<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}

  // group cols crossing lanes:
  //   a_grouped = {a0, a1, b0, b1}
  //   b_grouped = {a2, a3, b2, b3}
  static constexpr int group_ctrl = 0 | (2 << 2) | (1 << 4) | (3 << 6);  // 0, 2, 1, 3
  auto a_grouped = _mm256_permute4x64_pd(a, group_ctrl);
  auto b_grouped = _mm256_permute4x64_pd(b, group_ctrl);

  // swap lanes:
  //   return {a0, a1, a2, a3}
  //          {b0, b1, b2, b3}
  static constexpr int swap_ctrl_a = 0 | (2 << 4);  // 0, 2.   4 bits apart
  static constexpr int swap_ctrl_b = 1 | (3 << 4);  // 1, 3.   4 bits apart
  return std::make_pair(_mm256_permute2f128_pd(a_grouped, b_grouped, swap_ctrl_a),
                        _mm256_permute2f128_pd(a_grouped, b_grouped, swap_ctrl_b));
}
206 
// Inverse of interleave2<float>: splits two interleaved float vectors back
// into separate a and b columns.
template <>
std::pair<Vec256<float>, Vec256<float>>
inline deinterleave2<float>(const Vec256<float>& a, const Vec256<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}

  // group cols crossing lanes:
  //   a_grouped = {a0, a1, a2, a3, b0, b1, b2, b3}
  //   b_grouped = {a4, a5, a6, a7, b4, b5, b6, b7}
  // TODO: can we support caching this?
  // (permutevar8x32 takes its shuffle indices in a vector register)
  const __m256i group_ctrl = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
  auto a_grouped = _mm256_permutevar8x32_ps(a, group_ctrl);
  auto b_grouped = _mm256_permutevar8x32_ps(b, group_ctrl);

  // swap lanes:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
  //          {b0, b1, b2, b3, b4, b5, b6, b7}
  static constexpr int swap_ctrl_a = 0 | (2 << 4);  // 0, 2.   4 bits apart
  static constexpr int swap_ctrl_b = 1 | (3 << 4);  // 1, 3.   4 bits apart
  return std::make_pair(_mm256_permute2f128_ps(a_grouped, b_grouped, swap_ctrl_a),
                        _mm256_permute2f128_ps(a_grouped, b_grouped, swap_ctrl_b));
}
230 
231 #endif // defined(__AVX2__)
232 
233 #endif // defined(__AVX__) && !defined(_MSC_VER)
234 
235 }}}
// NOTE: this text appears truncated from a longer note about
// Flush-To-Zero and Denormals-Are-Zero mode.