static constexpr double QEPSILON = 1e-8;
void quantize_and_compress__avx2(
    const float* input_data,
    uint8_t* output_data,
    uint64_t input_size,
    uint64_t bitwidth,
    bool random,
    const float* random_buffer) {
  // Gathers the low byte of each 32-bit element into four contiguous bytes per
  // 128-bit half; 0xff entries zero out the remaining byte positions.
  __m256i shuffle_mask_v = _mm256_set_epi8(
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,
      0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);
  // Moves the packed dword of the upper half (index 4) next to that of the
  // lower half so the eight result bytes land in the first 64 bits.
  __m256i permute_mask_v =
      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
  uint64_t data_per_byte = 8 / bitwidth;
  uint64_t tail = input_size % data_per_byte;
  tail = tail ? data_per_byte - tail : 0;
  uint64_t segment_size = (input_size + data_per_byte - 1) / data_per_byte;
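  // Example: input_size = 10, bitwidth = 2 gives data_per_byte = 4 values per
  // payload byte, segment_size = ceil(10 / 4) = 3 payload bytes, and
  // tail = 3 * 4 - 10 = 2 unused slots in the final positions.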
  // Find the range of the input; min/max go into the header so the
  // decompressor can reconstruct values.
  float minimum_element = INFINITY, maximum_element = -INFINITY;
  for (auto i = 0; i < input_size; ++i) {
    minimum_element =
        (input_data[i] < minimum_element) ? input_data[i] : minimum_element;
    maximum_element =
        (input_data[i] > maximum_element) ? input_data[i] : maximum_element;
  }
  output_data[0] = bitwidth;
  output_data[1] = tail;
  reinterpret_cast<float*>(output_data + 2)[0] = minimum_element;
  reinterpret_cast<float*>(output_data + 2)[1] = maximum_element;
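  // Header layout (10 bytes): byte 0 = bitwidth, byte 1 = tail, bytes 2-5 =
  // minimum_element (float), bytes 6-9 = maximum_element (float). The packed
  // payload starts at output_data + 10.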
  float gap = (maximum_element - minimum_element) / ((1 << bitwidth) - 1.0f);
  float gap_inverse = 1. / (gap + QEPSILON);
  uint8_t max_q = (1 << bitwidth) - 1;
  uint64_t bit_start = 0;
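  // A value v maps to q = (v - minimum_element) * gap_inverse, clamped to
  // [0, max_q] and then rounded, either stochastically (floor(q + r) with r
  // uniform in [0, 1)) or to nearest. QEPSILON keeps gap_inverse finite when
  // all inputs are equal.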
  if (random) {
    // Stochastic-rounding path: random_buffer supplies one uniform [0, 1)
    // value per input element.
    for (int start = 0; start < input_size; start += segment_size) {
      uint64_t stride = start + segment_size <= input_size
          ? segment_size
          : input_size - start;
      uint64_t i = 0;
      constexpr int VLEN = 8;
      for (; i < stride / VLEN * VLEN; i += VLEN) {
        __m256 r_v = _mm256_loadu_ps(&random_buffer[start + i]);
        __m256 fval_v = _mm256_loadu_ps(input_data + start + i);
        __m256 thetimes_v = _mm256_mul_ps(
            _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)),
            _mm256_set1_ps(gap_inverse));
        __m256 rounded_v = _mm256_floor_ps(_mm256_add_ps(thetimes_v, r_v));
        rounded_v = _mm256_max_ps(
            _mm256_setzero_ps(),
            _mm256_min_ps(_mm256_set1_ps(max_q), rounded_v));
        __m256i qval_v = _mm256_cvtps_epi32(rounded_v);
        // OR the new bit-plane into the bytes written by earlier segments,
        // then compact one byte per element into eight contiguous bytes.
        __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
            reinterpret_cast<const __m128i*>(output_data + 10 + i)));
        orval_v =
            _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
        orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
        orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
        *reinterpret_cast<int64_t*>(output_data + 10 + i) =
            _mm256_extract_epi64(orval_v, 0);
      }
      for (; i < stride; ++i) {
        float fval = input_data[start + i];
        float thetimes = (fval - minimum_element) * gap_inverse;
        float rounded = floor(thetimes + random_buffer[start + i]);
        rounded = rounded < static_cast<float>(max_q)
            ? rounded
            : static_cast<float>(max_q);
        rounded = rounded > 0.0f ? rounded : 0.0f;
        uint8_t qval = rounded;

        uint8_t orval = output_data[10 + i];
        output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
      }
      // Each segment occupies the next bitwidth bits of the same payload bytes.
      bit_start += bitwidth;
    }
  } else {
    // Deterministic path: identical packing, but round to nearest instead of
    // stochastic rounding.
    for (int start = 0; start < input_size; start += segment_size) {
      uint64_t stride = start + segment_size <= input_size
          ? segment_size
          : input_size - start;
      uint64_t i = 0;
      constexpr int VLEN = 8;
      for (; i < stride / VLEN * VLEN; i += VLEN) {
        __m256 fval_v = _mm256_loadu_ps(input_data + start + i);
        __m256 thetimes_v = _mm256_mul_ps(
            _mm256_sub_ps(fval_v, _mm256_set1_ps(minimum_element)),
            _mm256_set1_ps(gap_inverse));
        thetimes_v = _mm256_max_ps(
            _mm256_setzero_ps(),
            _mm256_min_ps(_mm256_set1_ps(max_q), thetimes_v));
        __m256i qval_v = _mm256_cvtps_epi32(_mm256_round_ps(
            thetimes_v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
        __m256i orval_v = _mm256_cvtepu8_epi32(_mm_lddqu_si128(
            reinterpret_cast<const __m128i*>(output_data + 10 + i)));
        orval_v =
            _mm256_or_si256(orval_v, _mm256_slli_epi32(qval_v, bit_start));
        orval_v = _mm256_shuffle_epi8(orval_v, shuffle_mask_v);
        orval_v = _mm256_permutevar8x32_epi32(orval_v, permute_mask_v);
        *reinterpret_cast<int64_t*>(output_data + 10 + i) =
            _mm256_extract_epi64(orval_v, 0);
      }
      for (; i < stride; ++i) {
        float fval = input_data[start + i];
        float thetimes = (fval - minimum_element) * gap_inverse;
        thetimes = thetimes < static_cast<float>(max_q)
            ? thetimes
            : static_cast<float>(max_q);
        thetimes = thetimes > 0.0f ? thetimes : 0.0f;
        uint8_t qval = nearbyint(thetimes);

        uint8_t orval = output_data[10 + i];
        output_data[10 + i] = orval | static_cast<uint8_t>(qval << bit_start);
      }
      bit_start += bitwidth;
    }
  }
}
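// The compressed stream written above occupies 10 + segment_size bytes: the
// fixed header followed by one payload byte per group of data_per_byte values.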
void decompress_and_dequantize__avx2(
    const uint8_t* input_data,
    float* output_data,
    uint64_t input_size) {
  // Read back the 10-byte header written by quantize_and_compress__avx2.
  const float minimum_element =
      reinterpret_cast<const float*>(input_data + 2)[0];
  const float maximum_element =
      reinterpret_cast<const float*>(input_data + 2)[1];
  const uint64_t bitwidth = input_data[0];
  const float gap =
      (maximum_element - minimum_element) / ((1 << bitwidth) - 1.f) +
      QEPSILON; // matches the gap + QEPSILON used on the compression side
  const uint64_t tail = input_data[1];

  const uint64_t output_size = (input_size - 10) * (8 / bitwidth) - tail;

  uint64_t bit_start = 0;
  const uint64_t segment_size = input_size - 10;
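  // Mirror of the encoder: the payload is input_size - 10 bytes, each holding
  // 8 / bitwidth values at successive bit offsets; tail trims the unused
  // slots from the final count.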
  for (int start = 0; start < output_size; start += segment_size) {
    uint64_t stride = start + segment_size <= output_size
        ? segment_size
        : output_size - start;
    uint8_t mask = (1 << bitwidth) - 1;
    uint64_t i = 0;
    constexpr int VLEN = 8;
    for (; i < stride / VLEN * VLEN; i += VLEN) {
      __m128i in_v = _mm_lddqu_si128(
          reinterpret_cast<const __m128i*>(input_data + 10 + i));
      // Shift the current bit-plane down, mask it out, and rescale:
      // value = q * gap + minimum_element.
      __m256i out_epi32_v = _mm256_and_si256(
          _mm256_srli_epi32(_mm256_cvtepu8_epi32(in_v), bit_start),
          _mm256_set1_epi32(mask));
      __m256 out_v = _mm256_fmadd_ps(
          _mm256_cvtepi32_ps(out_epi32_v),
          _mm256_set1_ps(gap),
          _mm256_set1_ps(minimum_element));
      _mm256_storeu_ps(output_data + start + i, out_v);
    }
    for (; i < stride; ++i) {
      output_data[start + i] =
          ((input_data[10 + i] >> bit_start) & mask) * gap + minimum_element;
    }
    bit_start += bitwidth;
  }
}
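// A minimal round-trip sketch, not part of the original source: it assumes the
// two functions above are visible in the current scope and that bitwidth
// divides 8. The helper name example_roundtrip and the extra slack bytes on
// the compressed buffer (added because the AVX2 inner loops touch the payload
// in 8- and 16-byte chunks) are this example's own choices.
#include <cstdint>
#include <vector>

inline void example_roundtrip(const std::vector<float>& values) {
  const uint64_t bitwidth = 2; // bits per quantized value
  const uint64_t data_per_byte = 8 / bitwidth;
  const uint64_t payload_bytes =
      (values.size() + data_per_byte - 1) / data_per_byte;
  const uint64_t compressed_size = 10 + payload_bytes; // header + payload

  std::vector<uint8_t> compressed(compressed_size + 16, 0); // slack at the end
  std::vector<float> restored(values.size());

  // Deterministic (round-to-nearest) quantization: no random buffer needed.
  quantize_and_compress__avx2(
      values.data(),
      compressed.data(),
      values.size(),
      bitwidth,
      /*random=*/false,
      /*random_buffer=*/nullptr);

  // input_size here is the logical compressed size, excluding the slack.
  decompress_and_dequantize__avx2(
      compressed.data(), restored.data(), compressed_size);
}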
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...