// Number of elements processed per 256-bit AVX2 vector: each step loads
// 8 bytes (_mm_loadl_epi64) and widens them to 8 32-bit lanes
// (_mm256_cvtepi8_epi32 -> __m256). The main loop is unrolled 4x, so it
// consumes VLEN * 4 elements per iteration.
constexpr int VLEN = 8;
15 template <
typename T,
bool ReluFused>
16 void ElementWiseSumAVX2(
26 int32_t c_zero_point) {
27 __m256i permute_mask_v =
28 _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
30 int len_aligned = len / (VLEN * 4) * (VLEN * 4);
32 for (; j < len_aligned; j += VLEN * 4) {
36 __m256 in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
37 _mm256_cvtepi8_epi32(_mm_sub_epi8(
38 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input0 + j)),
39 _mm_set1_epi8(0x80))),
40 _mm256_set1_epi32(0x80)));
41 in_v0 = _mm256_fmadd_ps(
43 _mm256_set1_ps(a_scale),
44 _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
46 __m256 in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
47 _mm256_cvtepi8_epi32(_mm_sub_epi8(
48 _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input1 + j)),
49 _mm_set1_epi8(0x80))),
50 _mm256_set1_epi32(0x80)));
51 __m256 acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
53 __m256 x_transformed_v = _mm256_fmadd_ps(
54 acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
57 in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
58 _mm256_cvtepi8_epi32(_mm_sub_epi8(
60 reinterpret_cast<const __m128i*>(input0 + j + VLEN)),
61 _mm_set1_epi8(0x80))),
62 _mm256_set1_epi32(0x80)));
63 in_v0 = _mm256_fmadd_ps(
65 _mm256_set1_ps(a_scale),
66 _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
68 in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
69 _mm256_cvtepi8_epi32(_mm_sub_epi8(
71 reinterpret_cast<const __m128i*>(input1 + j + VLEN)),
72 _mm_set1_epi8(0x80))),
73 _mm256_set1_epi32(0x80)));
74 acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
76 __m256 y_transformed_v = _mm256_fmadd_ps(
77 acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
80 in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
81 _mm256_cvtepi8_epi32(_mm_sub_epi8(
83 reinterpret_cast<const __m128i*>(input0 + j + 2 * VLEN)),
84 _mm_set1_epi8(0x80))),
85 _mm256_set1_epi32(0x80)));
86 in_v0 = _mm256_fmadd_ps(
88 _mm256_set1_ps(a_scale),
89 _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
91 in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
92 _mm256_cvtepi8_epi32(_mm_sub_epi8(
94 reinterpret_cast<const __m128i*>(input1 + j + 2 * VLEN)),
95 _mm_set1_epi8(0x80))),
96 _mm256_set1_epi32(0x80)));
97 acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
99 __m256 z_transformed_v = _mm256_fmadd_ps(
100 acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
103 in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
104 _mm256_cvtepi8_epi32(_mm_sub_epi8(
106 reinterpret_cast<const __m128i*>(input0 + j + 3 * VLEN)),
107 _mm_set1_epi8(0x80))),
108 _mm256_set1_epi32(0x80)));
109 in_v0 = _mm256_fmadd_ps(
111 _mm256_set1_ps(a_scale),
112 _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
114 in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
115 _mm256_cvtepi8_epi32(_mm_sub_epi8(
117 reinterpret_cast<const __m128i*>(input1 + j + 3 * VLEN)),
118 _mm_set1_epi8(0x80))),
119 _mm256_set1_epi32(0x80)));
120 acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
122 __m256 w_transformed_v = _mm256_fmadd_ps(
123 acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
127 __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
128 __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
129 __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
130 __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);
132 __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
133 __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
134 __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
135 __m256i xyzw_clamped_v = _mm256_max_epu8(
136 ReluFused ? _mm256_set1_epi8(c_zero_point) : _mm256_setzero_si256(),
138 xyzw_packed_v, _mm256_set1_epi8(static_cast<uint8_t>(255))));
141 _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
142 _mm256_storeu_si256(reinterpret_cast<__m256i*>(output + j), xyzw_clamped_v);
144 for (; j < len; ++j) {
146 acc += (input0[j] - a_zero_point) * a_scale;
147 acc += (input1[j] - b_zero_point) * b_scale;
148 float transformed_val = c_zero_point + acc / c_scale;
149 output[j] = std::max(
150 ReluFused ? c_zero_point : 0.0f,
151 std::min(255.0f, nearbyint(transformed_val)));
155 template void ElementWiseSumAVX2<uint8_t, false>(
156 const uint8_t* input0,
157 const uint8_t* input1,
161 int32_t a_zero_point,
163 int32_t b_zero_point,
165 int32_t c_zero_point);
167 template void ElementWiseSumAVX2<uint8_t, true>(
168 const uint8_t* input0,
169 const uint8_t* input1,
173 int32_t a_zero_point,
175 int32_t b_zero_point,
177 int32_t c_zero_point);
179 template void ElementWiseSumAVX2<uint16_t, false>(
180 const uint16_t* input0,
181 const uint16_t* input1,
185 int32_t a_zero_point,
187 int32_t b_zero_point,
189 int32_t c_zero_point);
191 template void ElementWiseSumAVX2<uint16_t, true>(
192 const uint16_t* input0,
193 const uint16_t* input1,
197 int32_t a_zero_point,
199 int32_t b_zero_point,
201 int32_t c_zero_point);
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...