Caffe2 - C++ API
A deep learning, cross platform ML framework
elementwise_sum_dnnlowp_op_avx2.cc
1 #include <algorithm>
2 #include <cmath>
3 #include <cstdint>
4 
5 #include <immintrin.h>
6 
7 namespace caffe2 {
8 
9 namespace internal {
10 
11 using namespace std;
12 
13 constexpr int VLEN = 8;
14 
15 template <typename T, bool ReluFused>
16 void ElementWiseSumAVX2(
17  const T* input0,
18  const T* input1,
19  T* output,
20  int len,
21  float a_scale,
22  int32_t a_zero_point,
23  float b_scale,
24  int32_t b_zero_point,
25  float c_scale,
26  int32_t c_zero_point) {
27  __m256i permute_mask_v =
28  _mm256_set_epi32(0x07, 0x03, 0x06, 0x02, 0x05, 0x01, 0x04, 0x00);
29 
30  int len_aligned = len / (VLEN * 4) * (VLEN * 4);
31  int j = 0;
32  for (; j < len_aligned; j += VLEN * 4) {
33  // Input is uint8_t but cvtepi8_epi32 assumes the input is int8_t,
34  // so we subtract 0x80, cvtepi8_epi32, and then add 0x80
35  // x
36  __m256 in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
37  _mm256_cvtepi8_epi32(_mm_sub_epi8(
38  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input0 + j)),
39  _mm_set1_epi8(0x80))),
40  _mm256_set1_epi32(0x80)));
41  in_v0 = _mm256_fmadd_ps(
42  in_v0,
43  _mm256_set1_ps(a_scale),
44  _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
45 
46  __m256 in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
47  _mm256_cvtepi8_epi32(_mm_sub_epi8(
48  _mm_loadl_epi64(reinterpret_cast<const __m128i*>(input1 + j)),
49  _mm_set1_epi8(0x80))),
50  _mm256_set1_epi32(0x80)));
51  __m256 acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
52 
53  __m256 x_transformed_v = _mm256_fmadd_ps(
54  acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
55 
56  // y
57  in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
58  _mm256_cvtepi8_epi32(_mm_sub_epi8(
59  _mm_loadl_epi64(
60  reinterpret_cast<const __m128i*>(input0 + j + VLEN)),
61  _mm_set1_epi8(0x80))),
62  _mm256_set1_epi32(0x80)));
63  in_v0 = _mm256_fmadd_ps(
64  in_v0,
65  _mm256_set1_ps(a_scale),
66  _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
67 
68  in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
69  _mm256_cvtepi8_epi32(_mm_sub_epi8(
70  _mm_loadl_epi64(
71  reinterpret_cast<const __m128i*>(input1 + j + VLEN)),
72  _mm_set1_epi8(0x80))),
73  _mm256_set1_epi32(0x80)));
74  acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
75 
76  __m256 y_transformed_v = _mm256_fmadd_ps(
77  acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
78 
79  // z
80  in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
81  _mm256_cvtepi8_epi32(_mm_sub_epi8(
82  _mm_loadl_epi64(
83  reinterpret_cast<const __m128i*>(input0 + j + 2 * VLEN)),
84  _mm_set1_epi8(0x80))),
85  _mm256_set1_epi32(0x80)));
86  in_v0 = _mm256_fmadd_ps(
87  in_v0,
88  _mm256_set1_ps(a_scale),
89  _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
90 
91  in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
92  _mm256_cvtepi8_epi32(_mm_sub_epi8(
93  _mm_loadl_epi64(
94  reinterpret_cast<const __m128i*>(input1 + j + 2 * VLEN)),
95  _mm_set1_epi8(0x80))),
96  _mm256_set1_epi32(0x80)));
97  acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
98 
99  __m256 z_transformed_v = _mm256_fmadd_ps(
100  acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
101 
102  // w
103  in_v0 = _mm256_cvtepi32_ps(_mm256_add_epi32(
104  _mm256_cvtepi8_epi32(_mm_sub_epi8(
105  _mm_loadl_epi64(
106  reinterpret_cast<const __m128i*>(input0 + j + 3 * VLEN)),
107  _mm_set1_epi8(0x80))),
108  _mm256_set1_epi32(0x80)));
109  in_v0 = _mm256_fmadd_ps(
110  in_v0,
111  _mm256_set1_ps(a_scale),
112  _mm256_set1_ps(-a_zero_point * a_scale - b_zero_point * b_scale));
113 
114  in_v1 = _mm256_cvtepi32_ps(_mm256_add_epi32(
115  _mm256_cvtepi8_epi32(_mm_sub_epi8(
116  _mm_loadl_epi64(
117  reinterpret_cast<const __m128i*>(input1 + j + 3 * VLEN)),
118  _mm_set1_epi8(0x80))),
119  _mm256_set1_epi32(0x80)));
120  acc_v = _mm256_fmadd_ps(in_v1, _mm256_set1_ps(b_scale), in_v0);
121 
122  __m256 w_transformed_v = _mm256_fmadd_ps(
123  acc_v, _mm256_set1_ps(1.0 / c_scale), _mm256_set1_ps(c_zero_point));
124 
125  // See fbgemm/src/QuantUtilsAvx2.cc requantizeOutputProcessingAvx2 function
126  // for more details on this instruction sequence
127  __m256i x_rounded_v = _mm256_cvtps_epi32(x_transformed_v);
128  __m256i y_rounded_v = _mm256_cvtps_epi32(y_transformed_v);
129  __m256i z_rounded_v = _mm256_cvtps_epi32(z_transformed_v);
130  __m256i w_rounded_v = _mm256_cvtps_epi32(w_transformed_v);
131 
132  __m256i xy_packed_v = _mm256_packs_epi32(x_rounded_v, y_rounded_v);
133  __m256i zw_packed_v = _mm256_packs_epi32(z_rounded_v, w_rounded_v);
134  __m256i xyzw_packed_v = _mm256_packus_epi16(xy_packed_v, zw_packed_v);
135  __m256i xyzw_clamped_v = _mm256_max_epu8(
136  ReluFused ? _mm256_set1_epi8(c_zero_point) : _mm256_setzero_si256(),
137  _mm256_min_epu8(
138  xyzw_packed_v, _mm256_set1_epi8(static_cast<uint8_t>(255))));
139 
140  xyzw_clamped_v =
141  _mm256_permutevar8x32_epi32(xyzw_clamped_v, permute_mask_v);
142  _mm256_storeu_si256(reinterpret_cast<__m256i*>(output + j), xyzw_clamped_v);
143  }
144  for (; j < len; ++j) {
145  float acc = 0;
146  acc += (input0[j] - a_zero_point) * a_scale;
147  acc += (input1[j] - b_zero_point) * b_scale;
148  float transformed_val = c_zero_point + acc / c_scale;
149  output[j] = std::max(
150  ReluFused ? c_zero_point : 0.0f,
151  std::min(255.0f, nearbyint(transformed_val)));
152  }
153 }
154 
155 template void ElementWiseSumAVX2<uint8_t, false>(
156  const uint8_t* input0,
157  const uint8_t* input1,
158  uint8_t* output,
159  int len,
160  float a_scale,
161  int32_t a_zero_point,
162  float b_scale,
163  int32_t b_zero_point,
164  float c_scale,
165  int32_t c_zero_point);
166 
167 template void ElementWiseSumAVX2<uint8_t, true>(
168  const uint8_t* input0,
169  const uint8_t* input1,
170  uint8_t* output,
171  int len,
172  float a_scale,
173  int32_t a_zero_point,
174  float b_scale,
175  int32_t b_zero_point,
176  float c_scale,
177  int32_t c_zero_point);
178 
179 template void ElementWiseSumAVX2<uint16_t, false>(
180  const uint16_t* input0,
181  const uint16_t* input1,
182  uint16_t* output,
183  int len,
184  float a_scale,
185  int32_t a_zero_point,
186  float b_scale,
187  int32_t b_zero_point,
188  float c_scale,
189  int32_t c_zero_point);
190 
191 template void ElementWiseSumAVX2<uint16_t, true>(
192  const uint16_t* input0,
193  const uint16_t* input1,
194  uint16_t* output,
195  int len,
196  float a_scale,
197  int32_t a_zero_point,
198  float b_scale,
199  int32_t b_zero_point,
200  float c_scale,
201  int32_t c_zero_point);
202 
203 } // namespace internal
204 
205 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13