#include <immintrin.h>

#include <cstdint>
#include <limits>

#include <fbgemm/QuantUtils.h>

template <typename T>
void SegmentMomentsAVX2(
    const int N,
    const T* src,
    int64_t* sum,
    int64_t* sumsq);

template <>
void SegmentMomentsAVX2<uint8_t>(
    const int N,
    const uint8_t* src,
    int64_t* sum,
    int64_t* sumsq) {
  constexpr int kVLen = 16;
  const int n = N / kVLen * kVLen;
  const int r = N % kVLen;
  const __m256i kOneInt16 = _mm256_set1_epi16(0x01);
  __m256i sum_v = _mm256_setzero_si256();
  __m256i sumsq_v = _mm256_setzero_si256();
  for (int i = 0; i < n; i += kVLen) {
    // Widen 16 uint8 values to int16, then use madd to accumulate pairwise
    // sums (against a vector of ones) and pairwise squares into int32 lanes.
    const __m256i cur_v = _mm256_cvtepu8_epi16(
        _mm_loadu_si128(reinterpret_cast<const __m128i*>(src + i)));
    sum_v = _mm256_add_epi32(sum_v, _mm256_madd_epi16(cur_v, kOneInt16));
    sumsq_v = _mm256_add_epi32(sumsq_v, _mm256_madd_epi16(cur_v, cur_v));
  }
  int32_t sum_arr[8];
  int32_t sumsq_arr[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(sum_arr), sum_v);
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(sumsq_arr), sumsq_v);
  for (int i = 0; i < 8; ++i) {
    *sum += static_cast<int64_t>(sum_arr[i]);
    *sumsq += static_cast<int64_t>(sumsq_arr[i]);
  }
  // Scalar tail for the remaining N % 16 elements.
  for (int i = 0; i < r; ++i) {
    *sum += static_cast<int64_t>(src[n + i]);
    *sumsq +=
        static_cast<int64_t>(src[n + i]) * static_cast<int64_t>(src[n + i]);
  }
}
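// VectorMomentsAVX2 processes the input in segments of 32768 elements so the
// int32 lane accumulators in SegmentMomentsAVX2 cannot overflow: each lane
// sees at most 32768 / 16 = 2048 iterations of 2 * 255 * 255, roughly 2.7e8,
// comfortably below INT32_MAX.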
template <typename T>
void VectorMomentsAVX2(
    const int N, const T* src, int64_t* sum, int64_t* sumsq);

template <>
void VectorMomentsAVX2<uint8_t>(
    const int N,
    const uint8_t* src,
    int64_t* sum,
    int64_t* sumsq) {
  constexpr int kVLen = 32768;
  const int n = N / kVLen * kVLen;
  const int r = N % kVLen;
  for (int i = 0; i < n; i += kVLen) {
    SegmentMomentsAVX2<uint8_t>(kVLen, src + i, sum, sumsq);
  }
  if (r > 0) {
    SegmentMomentsAVX2<uint8_t>(r, src + n, sum, sumsq);
  }
}
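// ComputeQuantizedFusedParamsAVX2 folds the group statistics into per-channel
// affine parameters: scale = gamma * rsig and bias = beta - scale * (mu +
// X_zero_point), so the later pass only needs Y = scale * X + bias per
// element. The outer loop runs n from N - 1 down to 0 because the output
// block for n == 0 aliases the beta values still being read from the bias
// buffer, so it must be written last.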
void ComputeQuantizedFusedParamsAVX2(
    const int N,
    const int G,
    const int K,
    const int32_t X_zero_point,
    const int32_t* mu,
    const int32_t* rsig,
    const int32_t* gamma,
    int32_t* scale,
    int32_t* bias) {
  constexpr int kVLen = 8;
  const int k = K / kVLen * kVLen;
  const int r = K % kVLen;
  for (int n = N - 1; n >= 0; --n) {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int g = 0; g < G; ++g) {
      const __m256i mu_v = _mm256_set1_epi32(mu[n * G + g] + X_zero_point);
      const __m256i rsig_v = _mm256_set1_epi32(rsig[n * G + g]);
      for (int i = 0; i < k; i += kVLen) {
        const __m256i gamma_v =
            _mm256_loadu_si256((const __m256i*)(gamma + g * K + i));
        const __m256i beta_v =
            _mm256_loadu_si256((const __m256i*)(bias + g * K + i));
        __m256i scale_v = _mm256_mullo_epi32(gamma_v, rsig_v);
        __m256i bias_v =
            _mm256_sub_epi32(beta_v, _mm256_mullo_epi32(scale_v, mu_v));
        const int offset = (n * G + g) * K + i;
        _mm256_storeu_si256((__m256i*)(scale + offset), scale_v);
        _mm256_storeu_si256((__m256i*)(bias + offset), bias_v);
      }
      // Scalar tail for the remaining K % 8 channels of this group.
      for (int i = 0; i < r; ++i) {
        const int offset = (n * G + g) * K + k + i;
        scale[offset] = gamma[g * K + k + i] * rsig[n * G + g];
        bias[offset] = bias[g * K + k + i] -
            scale[offset] * (mu[n * G + g] + X_zero_point);
      }
    }
  }
}
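// The requantization macros implement fbgemm-style fixed-point requantization
// of 8 int32 lanes at a time: multiply by params.multiplier in 64-bit
// precision, add a rounding nudge, shift right by params.right_shift, add the
// target zero point, clamp to the uint8 range, and pack the 8 results into 8
// contiguous bytes. The 0x8000000000000000 offset added before the logical
// shift (and subtracted, pre-shifted, afterwards) makes _mm256_srli_epi64
// behave like an arithmetic shift for negative products.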
#define INIT_REQUANTIZE_AVX2                                                 \
  const __m256i b = _mm256_set1_epi32(params.multiplier);                    \
  const __m256i prev_shift_nudge = _mm256_set1_epi64x(                       \
      (1ll << (params.right_shift - 1)) + 0x8000000000000000ULL);            \
  const __m256i post_shift_nudge = _mm256_set1_epi64x(                       \
      params.target_qparams.zero_point -                                     \
      (0x8000000000000000ULL >> params.right_shift));                        \
  const __m256i min_v =                                                      \
      _mm256_set1_epi32(std::numeric_limits<uint8_t>::min());                \
  const __m256i max_v =                                                      \
      _mm256_set1_epi32(std::numeric_limits<uint8_t>::max());                \
  /* Picks byte 0 of each int32 lane within each 128-bit half. */            \
  const __m256i shuffle_mask_v = _mm256_set_epi8(                            \
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,                        \
      0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00,                        \
      0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff,                        \
      0xff, 0xff, 0xff, 0xff, 0x0c, 0x08, 0x04, 0x00);                       \
  const __m256i permute_mask_v =                                             \
      _mm256_set_epi32(0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00);

#define REQUANTIZE_AVX2(params, src, dst)                                    \
  do {                                                                       \
    __m256i a_v = (src);                                                     \
    /* Even/odd split lets _mm256_mul_epi32 form full 64-bit products. */    \
    __m256i a_even_v = a_v;                                                  \
    __m256i a_odd_v = _mm256_srli_si256(a_v, 4);                             \
    __m256i ab_even_v = _mm256_mul_epi32(a_even_v, b);                       \
    __m256i ab_odd_v = _mm256_mul_epi32(a_odd_v, b);                         \
    __m256i even_rounded_v = _mm256_add_epi64(ab_even_v, prev_shift_nudge);  \
    __m256i odd_rounded_v = _mm256_add_epi64(ab_odd_v, prev_shift_nudge);    \
    __m256i even_result_v = _mm256_add_epi64(                                \
        _mm256_srli_epi64(even_rounded_v, params.right_shift),               \
        post_shift_nudge);                                                   \
    __m256i odd_result_v = _mm256_add_epi64(                                 \
        _mm256_srli_epi64(odd_rounded_v, params.right_shift),                \
        post_shift_nudge);                                                   \
    odd_result_v = _mm256_slli_si256(odd_result_v, 4);                       \
    __m256i result_v = _mm256_blend_epi32(even_result_v, odd_result_v, 0xaa);\
    __m256i clipped_v =                                                      \
        _mm256_max_epi32(min_v, _mm256_min_epi32(max_v, result_v));          \
    clipped_v = _mm256_shuffle_epi8(clipped_v, shuffle_mask_v);              \
    clipped_v = _mm256_permutevar8x32_epi32(clipped_v, permute_mask_v);      \
    *(int64_t*)(dst) = _mm256_extract_epi64(clipped_v, 0);                   \
  } while (false)
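// AffineBatchChannelAndRequantizeNCHWAVX2 applies Y = requantize(scale * X +
// bias) where one scale/bias pair covers an entire HxW slice; the inner loop
// handles 8 pixels per iteration and a scalar tail covers HxW % 8.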
template <typename T>
void AffineBatchChannelAndRequantizeNCHWAVX2(
    const int N,
    const int C,
    const int HxW,
    const fbgemm::RequantizationParams& params,
    const T* X,
    const int32_t* scale,
    const int32_t* bias,
    T* Y);

template <>
void AffineBatchChannelAndRequantizeNCHWAVX2<uint8_t>(
    const int N,
    const int C,
    const int HxW,
    const fbgemm::RequantizationParams& params,
    const uint8_t* X,
    const int32_t* scale,
    const int32_t* bias,
    uint8_t* Y) {
  INIT_REQUANTIZE_AVX2;
  constexpr int kVLen = 8;
  const int outer_size = N * C;
  const int n = HxW / kVLen * kVLen;
  const int r = HxW % kVLen;
#ifdef _OPENMP
#pragma omp parallel for
#endif
  for (int i = 0; i < outer_size; ++i) {
    const uint8_t* X_ptr = X + i * HxW;
    uint8_t* Y_ptr = Y + i * HxW;
    // In NCHW, one (image, channel) slice shares a single scale and bias.
    const __m256i scale_v = _mm256_set1_epi32(scale[i]);
    const __m256i bias_v = _mm256_set1_epi32(bias[i]);
    for (int j = 0; j < n; j += kVLen) {
      const __m256i cur_v =
          _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)(X_ptr + j)));
      REQUANTIZE_AVX2(
          params,
          _mm256_add_epi32(_mm256_mullo_epi32(cur_v, scale_v), bias_v),
          (Y_ptr + j));
    }
    for (int j = 0; j < r; ++j) {
      Y_ptr[n + j] = fbgemm::Requantize<uint8_t>(
          static_cast<int32_t>(X_ptr[n + j]) * scale[i] + bias[i], params);
    }
  }
}
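// The NHWC variant differs only in memory layout: consecutive elements belong
// to consecutive channels, so instead of broadcasting one scale/bias pair per
// (image, channel) slice it loads full 8-lane scale and bias vectors for each
// group of 8 channels.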
template <typename T>
void AffineBatchChannelAndRequantizeNHWCAVX2(
    const int N,
    const int C,
    const int HxW,
    const fbgemm::RequantizationParams& params,
    const T* X,
    const int32_t* scale,
    const int32_t* bias,
    T* Y);

template <>
void AffineBatchChannelAndRequantizeNHWCAVX2<uint8_t>(
    const int N,
    const int C,
    const int HxW,
    const fbgemm::RequantizationParams& params,
    const uint8_t* X,
    const int32_t* scale,
    const int32_t* bias,
    uint8_t* Y) {
  INIT_REQUANTIZE_AVX2;
  constexpr int kVLen = 8;
  const int outer_size = N * HxW;
#ifdef _OPENMP
#pragma omp parallel for
#endif
  for (int i = 0; i < outer_size; ++i) {
    const int c = i / HxW * C;
    const int n = C / kVLen * kVLen;
    const int r = C % kVLen;
    const uint8_t* X_ptr = X + i * C;
    uint8_t* Y_ptr = Y + i * C;
    for (int j = 0; j < n; j += kVLen) {
      const __m256i cur_v =
          _mm256_cvtepu8_epi32(_mm_loadl_epi64((const __m128i*)(X_ptr + j)));
      const __m256i scale_v =
          _mm256_loadu_si256((const __m256i*)(scale + c + j));
      const __m256i bias_v =
          _mm256_loadu_si256((const __m256i*)(bias + c + j));
      REQUANTIZE_AVX2(
          params,
          _mm256_add_epi32(_mm256_mullo_epi32(cur_v, scale_v), bias_v),
          (Y_ptr + j));
    }
    for (int j = 0; j < r; ++j) {
      Y_ptr[n + j] = fbgemm::Requantize<uint8_t>(
          static_cast<int32_t>(X_ptr[n + j]) * scale[c + n + j] +
              bias[c + n + j],
          params);
    }
  }
}
#undef REQUANTIZE_AVX2
#undef INIT_REQUANTIZE_AVX2
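// Illustrative sketch (not part of the original file): one plausible way the
// kernels above could compose for a quantized GroupNorm forward pass over
// NCHW data. The function name and the assumption that mu, rsig, gamma, and
// the RequantizationParams are already quantized and supplied by the caller
// are hypothetical; the real operator derives them from its own quantization
// parameters.
inline void GroupNormNCHWExampleSketch(
    const int N,
    const int G,
    const int K, // channels per group
    const int HxW,
    const int32_t X_zero_point,
    const fbgemm::RequantizationParams& params,
    const uint8_t* X,
    const int32_t* mu, // per-(image, group) quantized mean (assumed given)
    const int32_t* rsig, // per-(image, group) quantized inverse stddev
    const int32_t* gamma, // per-channel quantized gamma
    int32_t* scale, // N * G * K scratch/output
    int32_t* bias, // first G * K entries hold quantized beta on entry
    uint8_t* Y) {
  // Fold the statistics and affine parameters into one scale/bias per channel.
  ComputeQuantizedFusedParamsAVX2(
      N, G, K, X_zero_point, mu, rsig, gamma, scale, bias);
  // Apply Y = requantize(scale * X + bias), one (image, channel) slice at a
  // time, with C = G * K total channels.
  AffineBatchChannelAndRequantizeNCHWAVX2<uint8_t>(
      N, G * K, HxW, params, X, scale, bias, Y);
}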