Caffe2 - C++ API
A deep learning, cross platform ML framework
norm_minimization_avx2.cc
1 #include <algorithm>
2 #include <cmath>
3 
4 #include <immintrin.h>
5 
namespace dnnlowp {

namespace internal {

/**
 * AVX2/FMA kernel computing the L2 error between a source histogram and a
 * coarser destination grid.
 *
 * Positions are measured from the beginning of destination bin 0. Source bin
 * `src_bin` covers [(src_bin - start_bin) * bin_width,
 * (src_bin - start_bin + 1) * bin_width), and its mass bins[src_bin] is
 * treated as spread uniformly over that interval
 * (density = bins[src_bin] / bin_width). The destination grid has
 * 2^precision bins of width dst_bin_width. A position x is assigned to bin
 * floor(x / dst_bin_width), clamped to [0, 2^precision - 1], and compared
 * against that bin's center. For each source bin the kernel adds
 * density * (integral over the source interval of (x - assigned center)^2),
 * evaluated in closed form:
 *
 *   (delta_end^3 - delta_begin^3) / 3
 *     + (dst_bin_of_end - dst_bin_of_begin) * dst_bin_width^3 / 12
 *
 * Here delta_begin and delta_end are the offsets of the interval's two ends
 * from the centers of the destination bins they are assigned to.
 *
 * The first nbins / 8 * 8 source bins are processed 8 at a time with AVX2 and
 * the remainder by a scalar loop. The vector loop multiplies by reciprocals
 * where the scalar loop divides, so the two paths may differ in the last few
 * ulps.
 *
 * @param precision     log2 of the number of destination bins; assumed small
 *                      enough that `1 << precision` does not overflow.
 * @param bins          Source histogram; only bins[0 .. nbins) is read.
 * @param nbins         Number of source bins; nbins <= 0 yields 0.
 * @param bin_width     Width of a source bin. It is divided by, so it is
 *                      expected to be non-zero.
 * @param dst_bin_width Width of a destination bin. It is divided by, so it is
 *                      expected to be non-zero.
 * @param start_bin     Source bin index aligned with the beginning of
 *                      destination bin 0.
 * @return Sum over all source bins of the density-weighted squared error.
 */
float L2MinimizationKernelAVX2(
    int precision,
    float* bins,
    int nbins,
    float bin_width,
    float dst_bin_width,
    int start_bin) {
  constexpr int VLEN = 8;
  // Error contributed by one destination bin lying entirely inside a source
  // bin: the integral of x^2 over [-dst_bin_width / 2, dst_bin_width / 2].
  const float norm_delta_default =
      dst_bin_width * dst_bin_width * dst_bin_width / 12;
  // Index of the last destination bin, kept as a float for the clamps below.
  const float last_dst_bin = (1 << precision) - 1.0f;

  // Loop-invariant operands, broadcast once.
  const __m256i lane_offsets_v = _mm256_set_epi32(7, 6, 5, 4, 3, 2, 1, 0);
  const __m256i start_bin_v = _mm256_set1_epi32(start_bin);
  const __m256 bin_width_v = _mm256_set1_ps(bin_width);
  const __m256 bin_width_inverse_v = _mm256_set1_ps(1.0f / bin_width);
  const __m256 dst_bin_width_v = _mm256_set1_ps(dst_bin_width);
  const __m256 dst_bin_width_inverse_v = _mm256_set1_ps(1.0f / dst_bin_width);
  const __m256 half_dst_bin_width_v = _mm256_set1_ps(dst_bin_width / 2);
  const __m256 last_dst_bin_v = _mm256_set1_ps(last_dst_bin);
  const __m256 zero_v = _mm256_setzero_ps();
  const __m256 minus_one_third_v = _mm256_set1_ps(-1.0f / 3);
  const __m256 one_third_v = _mm256_set1_ps(1.0f / 3);
  const __m256 norm_delta_default_v = _mm256_set1_ps(norm_delta_default);

  __m256 norm_v = _mm256_setzero_ps();
  int src_bin = 0;
  const int nbins_vectorized = nbins / VLEN * VLEN;
  for (; src_bin < nbins_vectorized; src_bin += VLEN) {
    // Distances from the beginning of the first dst_bin to the beginning and
    // end of each of the VLEN src_bins handled in this iteration.
    const __m256i src_bin_v =
        _mm256_add_epi32(_mm256_set1_epi32(src_bin), lane_offsets_v);
    const __m256 src_bin_begin_v = _mm256_mul_ps(
        _mm256_cvtepi32_ps(_mm256_sub_epi32(src_bin_v, start_bin_v)),
        bin_width_v);
    const __m256 src_bin_end_v = _mm256_add_ps(src_bin_begin_v, bin_width_v);

    // Which dst_bins the beginning and end of each src_bin belong to, clamped
    // to [0, last_dst_bin]. The values are integral after floor/min/max, so
    // the rounding mode used by cvtps_epi32 does not matter.
    const __m256i dst_bin_of_begin_v = _mm256_cvtps_epi32(_mm256_max_ps(
        zero_v,
        _mm256_min_ps(
            _mm256_floor_ps(
                _mm256_mul_ps(src_bin_begin_v, dst_bin_width_inverse_v)),
            last_dst_bin_v)));
    const __m256i dst_bin_of_end_v = _mm256_cvtps_epi32(_mm256_max_ps(
        zero_v,
        _mm256_min_ps(
            _mm256_floor_ps(
                _mm256_mul_ps(src_bin_end_v, dst_bin_width_inverse_v)),
            last_dst_bin_v)));

    const __m256 dst_bin_of_begin_center_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(dst_bin_of_begin_v),
        dst_bin_width_v,
        half_dst_bin_width_v);
    const __m256 dst_bin_of_end_center_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(dst_bin_of_end_v),
        dst_bin_width_v,
        half_dst_bin_width_v);

    // The VLEN src_bins are contiguous in memory, so a plain unaligned load
    // fetches them; a gather is unnecessary here.
    const __m256 density_v = _mm256_mul_ps(
        _mm256_loadu_ps(bins + src_bin), bin_width_inverse_v);

    // -delta_begin^3 / 3
    const __m256 delta_begin_v =
        _mm256_sub_ps(src_bin_begin_v, dst_bin_of_begin_center_v);
    __m256 norm_delta_v = _mm256_mul_ps(
        _mm256_mul_ps(
            _mm256_mul_ps(delta_begin_v, delta_begin_v), delta_begin_v),
        minus_one_third_v);

    // + delta_end^3 / 3. When a src_bin lies within a single dst_bin, the
    // begin and end centers are computed from identical inputs and are
    // therefore equal, so measuring against the end center covers both cases
    // without a compare/blend.
    const __m256 delta_end_v =
        _mm256_sub_ps(src_bin_end_v, dst_bin_of_end_center_v);
    norm_delta_v = _mm256_fmadd_ps(
        _mm256_mul_ps(_mm256_mul_ps(delta_end_v, delta_end_v), delta_end_v),
        one_third_v,
        norm_delta_v);

    // + one full dst_bin error per dst_bin boundary crossed (zero when the
    // src_bin lies within a single dst_bin).
    norm_delta_v = _mm256_fmadd_ps(
        _mm256_cvtepi32_ps(
            _mm256_sub_epi32(dst_bin_of_end_v, dst_bin_of_begin_v)),
        norm_delta_default_v,
        norm_delta_v);

    norm_v = _mm256_fmadd_ps(density_v, norm_delta_v, norm_v);
  } // src_bin loop vectorized

  // Horizontal sum of the vector accumulator, lane by lane.
  float norm_buf[VLEN];
  _mm256_storeu_ps(norm_buf, norm_v);
  float norm = 0;
  for (int i = 0; i < VLEN; ++i) {
    norm += norm_buf[i];
  }

  // Scalar tail for the remaining nbins % VLEN src_bins.
  for (; src_bin < nbins; ++src_bin) {
    // Distances from the beginning of the first dst_bin to the beginning and
    // end of src_bin.
    const float src_bin_begin = (src_bin - start_bin) * bin_width;
    const float src_bin_end = src_bin_begin + bin_width;

    // Which dst_bins the beginning and end of src_bin belong to.
    const int dst_bin_of_begin = static_cast<int>(std::min(
        last_dst_bin,
        std::max(0.0f, std::floor(src_bin_begin / dst_bin_width))));
    const int dst_bin_of_end = static_cast<int>(std::min(
        last_dst_bin,
        std::max(0.0f, std::floor(src_bin_end / dst_bin_width))));

    const float dst_bin_of_begin_center =
        dst_bin_of_begin * dst_bin_width + dst_bin_width / 2;
    const float density = bins[src_bin] / bin_width;
    const float delta_begin = src_bin_begin - dst_bin_of_begin_center;
    float norm_delta = -(delta_begin * delta_begin * delta_begin) / 3;
    if (dst_bin_of_begin == dst_bin_of_end) {
      // src_bin is entirely within one dst_bin.
      const float delta_end = src_bin_end - dst_bin_of_begin_center;
      norm_delta += (delta_end * delta_end * delta_end) / 3;
    } else {
      norm_delta += (dst_bin_of_end - dst_bin_of_begin) * norm_delta_default;

      const float dst_bin_of_end_center =
          dst_bin_of_end * dst_bin_width + dst_bin_width / 2;
      const float delta_end = src_bin_end - dst_bin_of_end_center;
      norm_delta += (delta_end * delta_end * delta_end) / 3;
    }
    norm += density * norm_delta;
  } // src_bin loop remainder

  return norm;
}

} // namespace internal

} // namespace dnnlowp