1 #include <cstdint>
2 #include <limits>
4 #include <immintrin.h>
6 #ifdef _OPENMP
7 #include <omp.h>
8 #endif
10 #include <fbgemm/QuantUtils.h>
12 namespace caffe2 {
14 namespace internal {
16 template <typename T>
17 void SegmentMomentsAVX2(
18  const int N,
19  const T* src,
20  int64_t* sum,
21  int64_t* sumsq);
23 template <>
24 void SegmentMomentsAVX2<uint8_t>(
25  const int N,
26  const uint8_t* src,
27  int64_t* sum,
28  int64_t* sumsq) {
29  constexpr int kVLen = 16;
30  const int n = N / kVLen * kVLen;
31  const int r = N % kVLen;
32  const __m256i kOneInt16 = _mm256_set1_epi16(0x01);
33  __m256i sum_v = _mm256_setzero_si256();
34  __m256i sumsq_v = _mm256_setzero_si256();
35  for (int i = 0; i < n; i += kVLen) {
36  const __m256i cur_v = _mm256_cvtepu8_epi16(
37  _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i)));
38  sum_v = _mm256_add_epi32(sum_v, _mm256_madd_epi16(cur_v, kOneInt16));
39  sumsq_v = _mm256_add_epi32(sumsq_v, _mm256_madd_epi16(cur_v, cur_v));
40  }
41  int32_t sum_arr[8];
42  int32_t sumsq_arr[8];
43  _mm256_storeu_si256(reinterpret_cast<__m256i*>(sum_arr), sum_v);
44  _mm256_storeu_si256(reinterpret_cast<__m256i*>(sumsq_arr), sumsq_v);
45  for (int i = 0; i < 8; ++i) {
46  *sum += static_cast<int64_t>(sum_arr[i]);
47  *sumsq += static_cast<int64_t>(sumsq_arr[i]);
48  }
49  for (int i = 0; i < r; ++i) {
50  *sum += static_cast<int64_t>(src[n + i]);
51  *sumsq +=
52  static_cast<int64_t>(src[n + i]) * static_cast<int64_t>(src[n + i]);
53  }
54 }
56 template <typename T>
57 void VectorMomentsAVX2(const int N, const T* src, int64_t* sum, int64_t* sumsq);
59 template <>
60 void VectorMomentsAVX2<uint8_t>(
61  const int N,
62  const uint8_t* src,
63  int64_t* sum,
64  int64_t* sumsq) {
65  constexpr int kVLen = 32768;
66  const int n = N / kVLen * kVLen;
67  const int r = N % kVLen;
68  for (int i = 0; i < n; i += kVLen) {
69  SegmentMomentsAVX2<uint8_t>(kVLen, src + i, sum, sumsq);
70  }
71  if (r > 0) {
72  SegmentMomentsAVX2<uint8_t>(r, src + n, sum, sumsq);
73  }
74 }
76 void ComputeQuantizedFusedParamsAVX2(
77  const int N,
78  const int G,
79  const int K,
80  const int32_t X_zero_point,
81  const int32_t* mu,
82  const int32_t* rsig,
83  const int32_t* gamma,
84  int32_t* scale,
85  int32_t* bias) {
86  constexpr int kVLen = 8;
87  const int k = K / kVLen * kVLen;
88  const int r = K % kVLen;
89  for (int n = N - 1; n >= 0; --n) {
90 #ifdef _OPENMP
91 #pragma omp parallel for
92 #endif
93  for (int g = 0; g < G; ++g) {
94  const __m256i mu_v = _mm256_set1_epi32(mu[n * G + g] + X_zero_point);
95  const __m256i rsig_v = _mm256_set1_epi32(rsig[n * G + g]);
96  for (int i = 0; i < k; i += kVLen) {
97  const __m256i gamma_v =
98  _mm256_loadu_si256((const __m256i*)(gamma + g * K + i));
99  const __m256i beta_v =
100  _mm256_loadu_si256((const __m256i*)(bias + g * K + i));
101  __m256i scale_v = _mm256_mullo_epi32(gamma_v, rsig_v);
102  __m256i bias_v =
103  _mm256_sub_epi32(beta_v, _mm256_mullo_epi32(scale_v, mu_v));
104  const int offset = (n * G + g) * K + i;
105  _mm256_storeu_si256((__m256i*)(scale + offset), scale_v);
106  _mm256_storeu_si256((__m256i*)(bias + offset), bias_v);
107  }
108  for (int i = 0; i < r; ++i) {
109  const int offset = (n * G + g) * K + k + i;
110  scale[offset] = gamma[g * K + k + i] * rsig[n * G + g];
111  bias[offset] = bias[g * K + k + i] -
112  scale[offset] * (mu[n * G + g] + X_zero_point);
113  }
114  }
115  }
116 }
118 #define INIT_REQUANTIZE_AVX2 \
119  const __m256i b = _mm256_set1_epi32(params.multiplier); \
120  const __m256i prev_shift_nudge = _mm256_set1_epi64x( \
121  (1ll << (params.right_shift - 1)) + 0x8000000000000000ULL); \
122  const __m256i post_shift_nudge = _mm256_set1_epi64x( \
123  params.target_qparams.zero_point - \
124  (0x8000000000000000ULL >> params.right_shift)); \
125  const __m256i min_v = \
126  _mm256_set1_epi32(std::numeric_limits<uint8_t>::min()); \
127  const __m256i max_v = \
128  _mm256_set1_epi32(std::numeric_limits<uint8_t>::max()); \
129  const __m256i shuffle_mask_v = _mm256_set_epi8( \
130  0xff, \
131  0xff, \
132  0xff, \
133  0xff, \
134  0xff, \
135  0xff, \
136  0xff, \
137  0xff, \
138  0xff, \
139  0xff, \
140  0xff, \
141  0xff, \
142  0x0c, \
143  0x08, \
144  0x04, \
145  0x00, \
146  0xff, \
147  0xff, \
148  0xff, \
149  0xff, \
150  0xff, \
151  0xff, \
152  0xff, \
153  0xff, \
154  0xff, \
155  0xff, \
156  0xff, \
157  0xff, \
158  0x0c, \
159  0x08, \
160  0x04, \
161  0x00); \
162  const __m256i permute_mask_v = \
163  _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);
// Requantizes the eight int32 values of the __m256i expression `src` to uint8
// (multiply by b, round-shift by params.right_shift, add the zero point, clamp
// to [0, 255]) and writes the 8 result bytes to `dst` (any alignment).
// Must be expanded in a scope where INIT_REQUANTIZE_AVX2 has already been
// expanded: it uses b, prev_shift_nudge, post_shift_nudge, min_v, max_v,
// shuffle_mask_v and permute_mask_v.
#define REQUANTIZE_AVX2(params, src, dst)                                     \
  do {                                                                        \
    __m256i a_v = (src);                                                      \
    /* _mm256_mul_epi32 reads only the even int32 lanes, so the odd lanes  */ \
    /* are moved into even position with a 4-byte in-lane shift.           */ \
    __m256i a_even_v = a_v;                                                   \
    __m256i a_odd_v = _mm256_srli_si256(a_v, 4);                              \
    __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);                        \
    __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);                          \
    __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, prev_shift_nudge);   \
    __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, prev_shift_nudge);     \
    __m256i even_result_v = _mm256_add_epi64(                                 \
        _mm256_srli_epi64(even_rounded_v, params.right_shift),                \
        post_shift_nudge);                                                    \
    __m256i odd_result_v = _mm256_add_epi64(                                  \
        _mm256_srli_epi64(odd_rounded_v, params.right_shift),                 \
        post_shift_nudge);                                                    \
    /* Put the odd results back into the odd int32 lanes and merge. */        \
    odd_result_v = _mm256_slli_si256(odd_result_v, 4);                        \
    __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa); \
    __m256i clipped_v =                                                       \
        _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v));           \
    /* Pack the low byte of every int32 lane into the low 8 bytes. */         \
    clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);               \
    clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);       \
    /* MOVQ store: no alignment requirement and no int64_t* type punning  */  \
    /* of the uint8_t buffer; unlike _mm256_extract_epi64 it is also      */  \
    /* available when compiling for 32-bit x86.                           */  \
    _mm_storel_epi64(                                                         \
        reinterpret_cast<__m128i*>(dst), _mm256_castsi256_si128(clipped_v));  \
  } while (false)
189 template <typename T>
190 void AffineBatchChannelAndRequantizeNCHWAVX2(
191  const int N,
192  const int C,
193  const int HxW,
194  const fbgemm::RequantizationParams& params,
195  const T* X,
196  const int32_t* scale,
197  const int32_t* bias,
198  T* Y);
200 template <>
201 void AffineBatchChannelAndRequantizeNCHWAVX2<uint8_t>(
202  const int N,
203  const int C,
204  const int HxW,
205  const fbgemm::RequantizationParams& params,
206  const uint8_t* X,
207  const int32_t* scale,
208  const int32_t* bias,
209  uint8_t* Y) {
211  constexpr int kVLen = 8;
212  const int outer_size = N * C;
213  const int n = HxW / kVLen * kVLen;
214  const int r = HxW % kVLen;
215 #ifdef _OPENMP
216 #pragma omp parallel for
217 #endif
218  for (int i = 0; i < outer_size; ++i) {
219  const uint8_t* X_ptr = X + i * HxW;
220  uint8_t* Y_ptr = Y + i * HxW;
221  const __m256i scale_v = _mm256_set1_epi32(scale[i]);
222  const __m256i bias_v = _mm256_set1_epi32(bias[i]);
223  for (int j = 0; j < n; j += kVLen) {
224  const __m256i cur_v =
225  _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)(X_ptr + j)));
227  params,
228  _mm256_add_epi32(_mm256_mullo_epi32(cur_v, scale_v), bias_v),
229  (Y_ptr + j));
230  }
231  for (int j = 0; j < r; ++j) {
232  Y_ptr[n + j] = fbgemm::Requantize<uint8_t>(
233  static_cast<int32_t>(X_ptr[n + j]) * scale[i] + bias[i], params);
234  }
235  }
236 }
238 template <typename T>
239 void AffineBatchChannelAndRequantizeNHWCAVX2(
240  const int N,
241  const int C,
242  const int HxW,
243  const fbgemm::RequantizationParams& params,
244  const T* X,
245  const int32_t* scale,
246  const int32_t* bias,
247  T* Y);
249 template <>
250 void AffineBatchChannelAndRequantizeNHWCAVX2<uint8_t>(
251  const int N,
252  const int C,
253  const int HxW,
254  const fbgemm::RequantizationParams& params,
255  const uint8_t* X,
256  const int32_t* scale,
257  const int32_t* bias,
258  uint8_t* Y) {
260  constexpr int kVLen = 8;
261  const int outer_size = N * HxW;
262 #ifdef _OPENMP
263 #pragma omp parallel for
264 #endif
265  for (int i = 0; i < outer_size; ++i) {
266  const int c = i / HxW * C;
267  const int n = C / kVLen * kVLen;
268  const int r = C % kVLen;
269  const uint8_t* X_ptr = X + i * C;
270  uint8_t* Y_ptr = Y + i * C;
271  for (int j = 0; j < n; j += kVLen) {
272  const __m256i cur_v =
273  _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)(X_ptr + j)));
274  const __m256i scale_v =
275  _mm256_loadu_si256((const __m256i*)(scale + c + j));
276  const __m256i bias_v = _mm256_loadu_si256((const __m256i*)(bias + c + j));
278  params,
279  _mm256_add_epi32(_mm256_mullo_epi32(cur_v, scale_v), bias_v),
280  (Y_ptr + j));
281  }
282  for (int j = 0; j < r; ++j) {
283  Y_ptr[n + j] = fbgemm::Requantize<uint8_t>(
284  static_cast<int32_t>(X_ptr[n + j]) * scale[c + n + j] +
285  bias[c + n + j],
286  params);
287  }
288  }
289 }
291 #undef REQUANTIZE_AVX2
294 } // namespace internal
296 } // namespace caffe2
