10 float L2MinimizationKernelAVX2(
18 constexpr
int VLEN = 8;
19 float norm_delta_default = dst_bin_width * dst_bin_width * dst_bin_width / 12;
21 __m256i identity_v = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
22 __m256 bin_width_v = _mm256_set1_ps(bin_width);
23 __m256 bin_width_inverse_v = _mm256_set1_ps(1.0f / bin_width);
24 __m256 dst_bin_width_v = _mm256_set1_ps(dst_bin_width);
25 __m256 dst_bin_width_inverse_v = _mm256_set1_ps(1.0f / dst_bin_width);
26 __m256 norm_v = _mm256_setzero_ps();
29 for (; src_bin < nbins / VLEN * VLEN; src_bin += VLEN) {
33 _mm256_add_epi32(_mm256_set1_epi32(src_bin), identity_v);
34 __m256 src_bin_begin_v = _mm256_mul_ps(
36 _mm256_sub_epi32(src_bin_v, _mm256_set1_epi32(start_bin))),
38 __m256 src_bin_end_v = _mm256_add_ps(src_bin_begin_v, bin_width_v);
41 __m256i dst_bin_of_begin_v = _mm256_cvtps_epi32(_mm256_max_ps(
45 _mm256_mul_ps(src_bin_begin_v, dst_bin_width_inverse_v)),
46 _mm256_set1_ps((1 << precision) - 1.0f))));
47 __m256i dst_bin_of_end_v = _mm256_cvtps_epi32(_mm256_max_ps(
51 _mm256_mul_ps(src_bin_end_v, dst_bin_width_inverse_v)),
52 _mm256_set1_ps((1 << precision) - 1.0f))));
54 __m256 dst_bin_of_begin_center_v = _mm256_fmadd_ps(
55 _mm256_cvtepi32_ps(dst_bin_of_begin_v),
57 _mm256_set1_ps(dst_bin_width / 2));
59 __m256 density_v = _mm256_mul_ps(
60 _mm256_i32gather_ps(bins, src_bin_v, 4), bin_width_inverse_v);
61 __m256 delta_begin_v =
62 _mm256_sub_ps(src_bin_begin_v, dst_bin_of_begin_center_v);
63 __m256 norm_delta_v = _mm256_mul_ps(
65 _mm256_mul_ps(delta_begin_v, delta_begin_v), delta_begin_v),
66 _mm256_set1_ps(-1.0f / 3));
67 __m256i mask_v = _mm256_cmpeq_epi32(dst_bin_of_begin_v, dst_bin_of_end_v);
70 _mm256_sub_ps(src_bin_end_v, dst_bin_of_begin_center_v);
72 __m256 dst_bin_of_end_center_v = _mm256_fmadd_ps(
73 _mm256_cvtepi32_ps(dst_bin_of_end_v),
75 _mm256_set1_ps(dst_bin_width / 2));
76 __m256 delta_end1_v = _mm256_sub_ps(src_bin_end_v, dst_bin_of_end_center_v);
77 __m256 delta_end_v = _mm256_blendv_ps(
78 delta_end1_v, delta_end0_v, _mm256_castsi256_ps(mask_v));
79 norm_delta_v = _mm256_fmadd_ps(
80 _mm256_mul_ps(_mm256_mul_ps(delta_end_v, delta_end_v), delta_end_v),
81 _mm256_set1_ps(1.0f / 3),
84 norm_delta_v = _mm256_fmadd_ps(
86 _mm256_sub_epi32(dst_bin_of_end_v, dst_bin_of_begin_v)),
87 _mm256_set1_ps(norm_delta_default),
90 norm_v = _mm256_fmadd_ps(density_v, norm_delta_v, norm_v);
93 _mm256_storeu_ps(norm_buf, norm_v);
94 for (
int i = 0; i < VLEN; ++i) {
98 for (; src_bin < nbins; ++src_bin) {
101 float src_bin_begin = (src_bin - start_bin) * bin_width;
102 float src_bin_end = src_bin_begin + bin_width;
105 int dst_bin_of_begin = std::min(
106 (1 << precision) - 1.0f,
107 std::max(0.0f, floorf(src_bin_begin / dst_bin_width)));
108 int dst_bin_of_end = std::min(
109 (1 << precision) - 1.0f,
110 std::max(0.0f, floorf(src_bin_end / dst_bin_width)));
112 float dst_bin_of_begin_center =
113 dst_bin_of_begin * dst_bin_width + dst_bin_width / 2;
114 float density = bins[src_bin] / bin_width;
115 float delta_begin = src_bin_begin - dst_bin_of_begin_center;
116 float norm_delta = -(delta_begin * delta_begin * delta_begin) / 3;
117 if (dst_bin_of_begin == dst_bin_of_end) {
119 float delta_end = src_bin_end - dst_bin_of_begin_center;
120 norm_delta += (delta_end * delta_end * delta_end) / 3;
122 norm_delta += (dst_bin_of_end - dst_bin_of_begin) * norm_delta_default;
124 float dst_bin_of_end_center =
125 dst_bin_of_end * dst_bin_width + dst_bin_width / 2;
126 float delta_end = src_bin_end - dst_bin_of_end_center;
127 norm_delta += (delta_end * delta_end * delta_end) / 3;
129 norm += density * norm_delta;