Caffe2 - C++ API
A deep learning, cross platform ML framework
math_cpu_avx2.cc
1 // Implements the math functions for CPU.
2 // The implementation in this file allows us to route the underlying numerical
3 // computation library to different compiler options (-mno-avx2 or -mavx2).
4 
#include <immintrin.h>
#include <cmath>
#include <cstdint>
#include <cstring>
8 
9 using std::uint64_t;
10 using std::uint8_t;
11 
12 namespace caffe2 {
13 
14 namespace math {
15 
// Small positive offset for the quantization step size. The encoder inverts
// (gap + QEPSILON), which stays finite when all inputs are equal (gap == 0);
// the decoder adds the same offset to its reconstructed step.
static constexpr double QEPSILON = 1.0e-8;
17 
18 void quantize_and_compress__avx2(
19  const float* input_data,
20  uint8_t* output_data,
21  uint64_t input_size,
22  uint64_t bitwidth,
23  bool random,
24  const float* random_buffer) {
25  __m256i shuffle_mask_v = _mm256_set_epi8(
26  0xff,
27  0xff,
28  0xff,
29  0xff,
30  0xff,
31  0xff,
32  0xff,
33  0xff,
34  0xff,
35  0xff,
36  0xff,
37  0xff,
38  0x0c,
39  0x08,
40  0x04,
41  0x00,
42  0xff,
43  0xff,
44  0xff,
45  0xff,
46  0xff,
47  0xff,
48  0xff,
49  0xff,
50  0xff,
51  0xff,
52  0xff,
53  0xff,
54  0x0c,
55  0x08,
56  0x04,
57  0x00);
58  __m256i permute_mask_v =
59  _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
60 
61  uint64_t data_per_byte = 8 / bitwidth;
62  uint64_t tail = input_size % data_per_byte;
63  tail = tail ? data_per_byte - tail : 0;
64  uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
65 
66  // basic info
67  float minimum_element = INFINITY, maximum_element = -INFINITY;
68  for (auto i = 0; i < input_size; ++i) {
69  minimum_element =
70  (input_data[i] < minimum_element) ? input_data[i] : minimum_element;
71  maximum_element =
72  (input_data[i] > maximum_element) ? input_data[i] : maximum_element;
73  }
74  output_data[0] = bitwidth;
75  output_data[1] = tail;
76  reinterpret_cast<float*>(output_data + 2)[0] = minimum_element;
77  reinterpret_cast<float*>(output_data + 2)[1] = maximum_element;
78 
79  float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
80  float gap_inverse = 1. / (gap + QEPSILON);
81  uint8_t max_q = (1 << bitwidth) - 1;
82  uint64_t bit_start = 0;
83  if (random) {
84  for (int start = 0; start < input_size; start += segment_size) {
85  uint64_t stride = start + segment_size <= input_size ? segment_size
86  : input_size - start;
87  int i = 0;
88  constexpr int VLEN = 8;
89  for (; i < stride / VLEN * VLEN; i += VLEN) {
90  __m256 r_v = _mm256_loadu_ps(&random_buffer[start + i]);
91  __m256 fval_v = _mm256_loadu_ps(input_data + start + i);
92  __m256 thetimes_v = _mm256_mul_ps(
93  _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)),
94  _mm256_set1_ps(gap_inverse));
95  __m256 rounded_v = _mm256_floor_ps(_mm256_add_ps(thetimes_v, r_v));
96  rounded_v = _mm256_max_ps(
97  _mm256_setzero_ps(),
98  _mm256_min_ps(_mm256_set1_ps(max_q), rounded_v));
99  __m256i qval_v = _mm256_cvtps_epi32(rounded_v);
100  __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
101  reinterpret_cast<const __m128i*>(output_data + 10 + i)));
102  orval_v =
103  _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
104  orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
105  orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
106  *reinterpret_cast<int64_t*>(output_data + 10 + i) =
107  _mm256_extract_epi64(orval_v, 0);
108  }
109  for (; i < stride; ++i) {
110  float fval = input_data[start + i];
111  float thetimes = (fval - minimum_element) * gap_inverse;
112  float rounded = floor(thetimes + random_buffer[start + i]);
113  rounded = rounded < static_cast<float>(max_q)
114  ? rounded
115  : static_cast<float>(max_q);
116  rounded = rounded > 0.0f ? rounded : 0.0f;
117  uint8_t qval = rounded;
118 
119  uint8_t orval = output_data[10 + i];
120  output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
121  }
122  bit_start += bitwidth;
123  }
124  } else {
125  // !random
126  for (int start = 0; start < input_size; start += segment_size) {
127  uint64_t stride = start + segment_size <= input_size ? segment_size
128  : input_size - start;
129  int i = 0;
130  constexpr int VLEN = 8;
131  for (; i < stride / VLEN * VLEN; i += VLEN) {
132  __m256 fval_v = _mm256_loadu_ps(input_data + start + i);
133  __m256 thetimes_v = _mm256_mul_ps(
134  _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)),
135  _mm256_set1_ps(gap_inverse));
136  thetimes_v = _mm256_max_ps(
137  _mm256_setzero_ps(),
138  _mm256_min_ps(_mm256_set1_ps(max_q), thetimes_v));
139  __m256i qval_v = _mm256_cvtps_epi32(_mm256_round_ps(
140  thetimes_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
141  __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
142  reinterpret_cast<const __m128i*>(output_data + 10 + i)));
143  orval_v =
144  _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
145  orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
146  orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
147  *reinterpret_cast<int64_t*>(output_data + 10 + i) =
148  _mm256_extract_epi64(orval_v, 0);
149  }
150  for (; i < stride; ++i) {
151  float fval = input_data[start + i];
152  float thetimes = (fval - minimum_element) * gap_inverse;
153  thetimes = thetimes < static_cast<float>(max_q)
154  ? thetimes
155  : static_cast<float>(max_q);
156  thetimes = thetimes > 0.0f ? thetimes : 0.0f;
157  uint8_t qval = nearbyint(thetimes);
158 
159  uint8_t orval = output_data[10 + i];
160  output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
161  }
162  bit_start += bitwidth;
163  }
164  } // !random
165 }
166 
167 void decompress_and_dequantize__avx2(
168  const uint8_t* input_data,
169  float* output_data,
170  uint64_t input_size) {
171  // basic info
172  const float minimum_element =
173  reinterpret_cast<const float*>(input_data + 2)[0];
174  const float maximum_element =
175  reinterpret_cast<const float*>(input_data + 2)[1];
176  const uint64_t bitwidth = input_data[0];
177  const float gap =
178  (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
179  QEPSILON; // for exact recovering
180 
181  const uint64_t tail = input_data[1];
182 
183  const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail;
184  // decoding
185  uint64_t bit_start = 0;
186  const uint64_t segment_size = input_size - 10;
187  for (int start = 0; start < output_size; start += segment_size) {
188  uint64_t stride = start + segment_size <= output_size ? segment_size
189  : output_size - start;
190  uint8_t mask = (1 << bitwidth) - 1;
191  int i = 0;
192  // Can process 8 elements at a time because we need to expand uint8_t
193  // to int32_t to use epi32 vector instructions.
194  constexpr int VLEN = 8;
195  for (; i < stride / VLEN * VLEN; i += VLEN) {
196  __m128i in_v = _mm_lddqu_si128(
197  reinterpret_cast<const __m128i*>(input_data + 10 + i));
198  __m256i out_epi32_v = _mm256_and_si256(
199  _mm256_srli_epi32(_mm256_cvtepu8_epi32(in_v), bit_start),
200  _mm256_set1_epi32(mask));
201  __m256 out_v = _mm256_fmadd_ps(
202  _mm256_cvtepi32_ps(out_epi32_v),
203  _mm256_set1_ps(gap),
204  _mm256_set1_ps(minimum_element));
205  _mm256_storeu_ps(output_data + start + i, out_v);
206  }
207  for (; i < stride; ++i) {
208  output_data[start + i] =
209  ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element;
210  }
211  bit_start += bitwidth;
212  }
213 }
214 
215 } // namespace math
216 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13