1 #include "kl_minimization.h" 2 #include "caffe2/core/logging.h" 8 TensorQuantizationParams KLDivergenceMinimization::ChooseQuantizationParams(
10 bool preserve_sparsity,
// ChooseQuantizationParams: picks a contiguous range of histogram bins
// [start_bin, start_bin + nbins_selected) of `hist` and returns the params the
// default QuantizationFactory chooses for that range's [min, max].
// For every candidate range it estimates a KL divergence (kl += p * log(p / q))
// between the source bins and their requantization into
// dst_nbins = (1 << precision) destination bins; the kept range is presumably
// the one with the smallest kl (the comparison lines are not visible here).
// preserve_sparsity: when true, only a single start_bin per range width is
// tried, placing the range around zero_bin (the rounded bin index of 0.0).
// NOTE(review): this excerpt is missing lines (the embedded numbering jumps;
// e.g. the declarations of total_sum, fraction, sum and kl and the closing
// braces are not visible) and will not compile as shown -- restore it from
// the original source before editing the logic.
12 const vector<uint64_t> bins = *hist.GetHistogram();
13 int nbins = bins.size();
14 int dst_nbins = 1 << precision;
15 float min = hist.Min(), max = hist.Max();
// Width of one source histogram bin, in value units.
18 double bin_width = (max - min) / nbins;
// Index of the source bin nearest to the value 0.0.
19 int zero_bin = round(-min / bin_width);
// NOTE(review): loop body not visible; presumably it accumulates total_sum,
// which is used below as the normalizer for p, q and the coverage log.
22 for (
int i = 0; i < nbins; ++i) {
// best_start_bins[w] = {best start_bin, its kl} for range width w
// (w runs from 1 to nbins; index 0 is unused).
26 vector<pair<int, double>> best_start_bins(nbins + 1);
// Each iteration handles one range width and writes only its own
// best_start_bins[nbins_selected] slot.
// NOTE(review): start_bin and src_bin are assigned below without a visible
// declaration; they must be declared inside this loop (not outside it) or the
// OpenMP iterations race on them -- confirm.
32 #pragma omp parallel for 34 for (
int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
36 double kl_min = numeric_limits<double>::max();
37 int best_start_bin = 0;
// By default, try every start_bin for which the whole range fits in [0, nbins).
39 int start_bin_begin = 0, start_bin_end = nbins - nbins_selected + 1;
40 if (preserve_sparsity) {
// Single candidate, roughly centred on zero_bin. start_bin_begin may be
// negative; out-of-range source bins are skipped by the
// `src_bin >= 0 && src_bin < nbins` checks below.
45 start_bin_begin = zero_bin - nbins_selected / 2;
46 start_bin_end = start_bin_begin + 1;
51 for (start_bin = start_bin_begin; start_bin < start_bin_end; ++start_bin) {
// Histogram mass to the left / right of the selected range. It is folded into
// the first / last in-range source bin (see the std::max / std::min checks
// below), so clipped values still count toward the edge bins.
55 uint64_t left_outliers = 0;
57 for (src_bin = 0; src_bin < start_bin; ++src_bin) {
58 left_outliers += bins[src_bin];
61 uint64_t right_outliers = 0;
62 for (src_bin = start_bin + nbins_selected; src_bin < nbins; ++src_bin) {
63 right_outliers += bins[src_bin];
// Each destination bin covers the fractional source interval
// [src_bin_begin_not_rounded, src_bin_end_not_rounded); `fraction` is the
// share of each touched source bin that falls inside that interval.
67 for (
int dst_bin = 0; dst_bin < dst_nbins; ++dst_bin) {
68 double non_zero_length = 0;
70 double src_bin_begin_not_rounded =
71 start_bin + (double)dst_bin * nbins_selected / dst_nbins;
72 int src_bin_begin = src_bin_begin_not_rounded;
73 double src_bin_end_not_rounded =
74 start_bin + (double)(dst_bin + 1) * nbins_selected / dst_nbins;
75 int src_bin_end = ceil(src_bin_end_not_rounded);
// First pass over the source bins of this destination bin: accumulates
// non_zero_length (and, on lines not visible here, presumably `sum`).
76 for (src_bin = src_bin_begin; src_bin < src_bin_end; ++src_bin) {
77 if (src_bin >= 0 && src_bin < nbins) {
78 double bin = bins[src_bin];
80 if (src_bin == src_bin_begin && src_bin == src_bin_end - 1) {
81 fraction = src_bin_end_not_rounded - src_bin_begin_not_rounded;
82 }
else if (src_bin == src_bin_begin) {
83 fraction = (src_bin_begin + 1) - src_bin_begin_not_rounded;
84 assert(fraction >= 0);
85 }
else if (src_bin == src_bin_end - 1) {
86 fraction = src_bin_end_not_rounded - (src_bin_end - 1);
87 assert(fraction >= 0);
92 if (src_bin == std::max(start_bin, 0)) {
96 std::min(start_bin + nbins_selected - 1, nbins - 1)) {
97 bin += right_outliers;
100 non_zero_length += fraction;
// Second pass: p is the normalized source mass of each bin; q spreads the
// destination bin's mass `sum` over non_zero_length, weighted by fraction.
105 for (src_bin = src_bin_begin; src_bin < src_bin_end; ++src_bin) {
106 if (src_bin >= 0 && src_bin < nbins) {
107 uint64_t bin = bins[src_bin];
109 if (src_bin == src_bin_begin && src_bin == src_bin_end - 1) {
110 fraction = src_bin_end_not_rounded - src_bin_begin_not_rounded;
111 }
else if (src_bin == src_bin_begin) {
112 fraction = (src_bin_begin + 1) - src_bin_begin_not_rounded;
113 }
else if (src_bin == src_bin_end - 1) {
114 fraction = src_bin_end_not_rounded - (src_bin_end - 1);
117 if (src_bin == std::max(start_bin, 0)) {
118 bin += left_outliers;
121 std::min(start_bin + nbins_selected - 1, nbins - 1)) {
122 bin += right_outliers;
// NOTE(review): log(p / q) needs p > 0 and q > 0; any guard for empty bins is
// on lines not visible in this excerpt.
126 double p = (double)bin / total_sum;
127 double q = sum * fraction / non_zero_length / total_sum;
128 kl += p * log(p / q);
137 best_start_bin = start_bin;
141 best_start_bins[nbins_selected] = {best_start_bin, kl_min};
// Across all range widths, keep the width (and its best start_bin) with the
// smallest recorded kl; presumably compared against kl_min on lines not shown.
144 double kl_min = numeric_limits<double>::max();
145 int best_nbins_selected = dst_nbins, best_start_bin = 0;
146 for (
int nbins_selected = 1; nbins_selected <= nbins; ++nbins_selected) {
147 double kl = best_start_bins[nbins_selected].second;
150 best_start_bin = best_start_bins[nbins_selected].first;
151 best_nbins_selected = nbins_selected;
// Share of the total histogram mass inside the chosen range (clamped to
// [0, nbins)); used only for the log message below.
155 double selected_sum = 0;
156 int i_begin = std::max(0, best_start_bin);
157 int i_end = std::min(nbins, best_start_bin + best_nbins_selected);
158 for (
int i = i_begin; i < i_end; ++i) {
159 selected_sum += bins[i];
// NOTE(review): " %%" is inserted with << (no printf-style formatting), so this
// most likely logs a literal "%%"; " %" was probably intended.
161 VLOG(2) <<
"best quantization range covers " 162 << (double)selected_sum / total_sum * 100 <<
" %%";
164 VLOG(2) <<
"best start_bin " << best_start_bin <<
" nbins_selected " 165 << best_nbins_selected;
// Map the chosen bin range back to real values: min is offset half a bin into
// best_start_bin, max half a bin past best_start_bin + best_nbins_selected.
167 min = hist.Min() + bin_width * (best_start_bin + 0.5);
168 max = hist.Min() + bin_width * (best_start_bin + best_nbins_selected + 0.5);
170 QuantizationFactory* qfactory = QuantizationFactory::GetDefaultInstance();
171 return qfactory->ChooseQuantizationParams(min, max);