Caffe2 - C++ API
A deep learning, cross-platform ML framework
kl_minimization.cc
1 #include "kl_minimization.h"
2 #include "caffe2/core/logging.h"
3 
4 using namespace std;
5 
6 namespace dnnlowp {
7 
8 TensorQuantizationParams KLDivergenceMinimization::ChooseQuantizationParams(
9  const Histogram& hist,
10  bool preserve_sparsity,
11  int precision) {
12  const vector<uint64_t> bins = *hist.GetHistogram();
13  int nbins = bins.size();
14  int dst_nbins = 1 << precision;
15  float min = hist.Min(), max = hist.Max();
16  assert(min <= 0.f);
17  assert(max >= 0.f);
18  double bin_width = (max - min) / nbins;
19  int zero_bin = round(-min / bin_width);
20 
21  double total_sum = 0;
22  for (int i = 0; i < nbins; ++i) {
23  total_sum += bins[i];
24  }
25 
26  vector<pair<int, double>> best_start_bins(nbins + 1);
27 
28  // Look at mapping [start_bin, start_bin + nbins_selected) to
29  // [0, 1 << precision) for every (start_bin, nbins_selected) combination and
30  // pick the one with smallest KL divergence
31 #ifdef _OPENMP
32 #pragma omp parallel for
33 #endif
34  for (int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
35  // if (nbins_selected % dst_nbins != 0) continue;
36  double kl_min = numeric_limits<double>::max();
37  int best_start_bin = 0;
38 
39  int start_bin_begin = 0, start_bin_end = nbins - nbins_selected + 1;
40  if (preserve_sparsity) {
41  if (min == 0) {
42  start_bin_begin = 0;
43  start_bin_end = 1;
44  } else {
45  start_bin_begin = zero_bin - nbins_selected / 2;
46  start_bin_end = start_bin_begin + 1;
47  }
48  }
49 
50  int start_bin;
51  for (start_bin = start_bin_begin; start_bin < start_bin_end; ++start_bin) {
52  double kl = 0;
53 
54  // sum outliers
55  uint64_t left_outliers = 0;
56  int src_bin;
57  for (src_bin = 0; src_bin < start_bin; ++src_bin) {
58  left_outliers += bins[src_bin];
59  }
60 
61  uint64_t right_outliers = 0;
62  for (src_bin = start_bin + nbins_selected; src_bin < nbins; ++src_bin) {
63  right_outliers += bins[src_bin];
64  }
65 
66  // each destination bin corresponds to a quantized value
67  for (int dst_bin = 0; dst_bin < dst_nbins; ++dst_bin) {
68  double non_zero_length = 0;
69  double sum = 0;
70  double src_bin_begin_not_rounded =
71  start_bin + (double)dst_bin * nbins_selected / dst_nbins;
72  int src_bin_begin = src_bin_begin_not_rounded;
73  double src_bin_end_not_rounded =
74  start_bin + (double)(dst_bin + 1) * nbins_selected / dst_nbins;
75  int src_bin_end = ceil(src_bin_end_not_rounded);
76  for (src_bin = src_bin_begin; src_bin < src_bin_end; ++src_bin) {
77  if (src_bin >= 0 && src_bin < nbins) {
78  double bin = bins[src_bin];
79  double fraction = 1;
80  if (src_bin == src_bin_begin && src_bin == src_bin_end - 1) {
81  fraction = src_bin_end_not_rounded - src_bin_begin_not_rounded;
82  } else if (src_bin == src_bin_begin) {
83  fraction = (src_bin_begin + 1) - src_bin_begin_not_rounded;
84  assert(fraction >= 0);
85  } else if (src_bin == src_bin_end - 1) {
86  fraction = src_bin_end_not_rounded - (src_bin_end - 1);
87  assert(fraction >= 0);
88  }
89  bin *= fraction;
90  sum += bin;
91 
92  if (src_bin == std::max(start_bin, 0)) {
93  bin += left_outliers;
94  }
95  if (src_bin ==
96  std::min(start_bin + nbins_selected - 1, nbins - 1)) {
97  bin += right_outliers;
98  }
99  if (bin > 0) {
100  non_zero_length += fraction;
101  }
102  }
103  } // src_bin
104 
105  for (src_bin = src_bin_begin; src_bin < src_bin_end; ++src_bin) {
106  if (src_bin >= 0 && src_bin < nbins) {
107  uint64_t bin = bins[src_bin];
108  double fraction = 1;
109  if (src_bin == src_bin_begin && src_bin == src_bin_end - 1) {
110  fraction = src_bin_end_not_rounded - src_bin_begin_not_rounded;
111  } else if (src_bin == src_bin_begin) {
112  fraction = (src_bin_begin + 1) - src_bin_begin_not_rounded;
113  } else if (src_bin == src_bin_end - 1) {
114  fraction = src_bin_end_not_rounded - (src_bin_end - 1);
115  }
116 
117  if (src_bin == std::max(start_bin, 0)) {
118  bin += left_outliers;
119  }
120  if (src_bin ==
121  std::min(start_bin + nbins_selected - 1, nbins - 1)) {
122  bin += right_outliers;
123  }
124  bin *= fraction;
125  if (bin > 0) {
126  double p = (double)bin / total_sum;
127  double q = sum * fraction / non_zero_length / total_sum;
128  kl += p * log(p / q);
129  }
130  }
131  } // src_bin
132  } // dst_bin
133 
134  assert(kl >= 0);
135  if (kl < kl_min) {
136  kl_min = kl;
137  best_start_bin = start_bin;
138  }
139  } // for each start_bin
140 
141  best_start_bins[nbins_selected] = {best_start_bin, kl_min};
142  } // for each nbins_selected
143 
144  double kl_min = numeric_limits<double>::max();
145  int best_nbins_selected = dst_nbins, best_start_bin = 0;
146  for (int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
147  double kl = best_start_bins[nbins_selected].second;
148  if (kl < kl_min) {
149  kl_min = kl;
150  best_start_bin = best_start_bins[nbins_selected].first;
151  best_nbins_selected = nbins_selected;
152  }
153  }
154 
155  double selected_sum = 0;
156  int i_begin = std::max(0, best_start_bin);
157  int i_end = std::min(nbins, best_start_bin + best_nbins_selected);
158  for (int i = i_begin; i < i_end; ++i) {
159  selected_sum += bins[i];
160  }
161  VLOG(2) << "best quantization range covers "
162  << (double)selected_sum / total_sum * 100 << " %%";
163 
164  VLOG(2) << "best start_bin " << best_start_bin << " nbins_selected "
165  << best_nbins_selected;
166 
167  min = hist.Min() + bin_width * (best_start_bin + 0.5);
168  max = hist.Min() + bin_width * (best_start_bin + best_nbins_selected + 0.5);
169 
170  QuantizationFactory* qfactory = QuantizationFactory::GetDefaultInstance();
171  return qfactory->ChooseQuantizationParams(min, max);
172 } // ChooseQuantizationParams
173 
174 } // namespace dnnlowp