3 #include "caffe2/core/logging.h" 7 static double GetSaturationRegionBegin_(
double max_abs_err) {
9 double x_s = atanh(1 - max_abs_err);
11 return 1 / floor(1 / x_s);
// Scans the non-negative quantized inputs x_q and checks, for each one, how
// well tanh is approximated by simply rescaling x_q with a power-of-two factor
// (a plain bit shift) from the input scale to the output scale.
// NOTE(review): this listing appears to have lost lines. The remaining
// parameters (max_abs_err and num_in_bits are used below), the declarations of
// x_q and y_q, the else branch of the shift, the body of the final if and the
// return statement are not visible here. Confirm against the complete file.
17 static int GetPassRegionEnd_(
18 TensorQuantizationParams in_qparams,
19 TensorQuantizationParams out_qparams,
// Largest positive magnitude of a signed value that is num_in_bits wide.
25 int in_pos_qmax = (1 << (num_in_bits - 1)) - 1;
// Ratio of input scale to output scale, rounded to the nearest power of two
// so the rescale below can be expressed as a shift.
27 float scale_multiplier = in_qparams.scale / out_qparams.scale;
28 int log2_scale_multiplier = nearbyint(log2(scale_multiplier));
31 for (x_q = 0; x_q < in_pos_qmax; ++x_q) {
// A negative exponent shifts right; the left shift below is presumably the
// else branch (the branch keyword itself is not visible here).
33 if (log2_scale_multiplier < 0) {
34 y_q = x_q >> (-log2_scale_multiplier);
36 y_q = x_q << (log2_scale_multiplier);
// Real-valued output that the shifted quantized value stands for.
38 float y = y_q * out_qparams.scale;
// Real-valued edges of the input bin represented by x_q; the lower edge is
// clamped at 0.
40 float x_min = std::max((x_q - 0.5f) * in_qparams.scale, 0.f);
41 float x_max = (x_q + 0.5f) * in_qparams.scale;
// True when the approximation misses tanh by more than max_abs_err at either
// bin edge.
// NOTE(review): what happens in that case is not visible; presumably the scan
// stops here - confirm.
42 if (fabs(tanh(x_max) - y) > max_abs_err ||
43 fabs(tanh(x_min) - y) > max_abs_err) {
// Constructor: derives the input/output quantization parameters from the
// error budget max_abs_err and fills processing_region_lut_, the table used
// for input magnitudes from x_pq_index_ upwards.
// NOTE(review): this listing appears to have lost lines (template header,
// the declaration of i, #endif directives, closing braces). Confirm against
// the complete file.
51 Tanh<T>::Tanh(
double max_abs_err) : max_abs_err_(max_abs_err) {
// Input value from which tanh is treated as saturated.
53 double x_sq = GetSaturationRegionBegin_(max_abs_err);
// Input quantization: the largest positive magnitude maps to x_sq and the
// zero point sits at 2^(num_in_bits_ - 1).
56 in_qparams_.scale = x_sq / ((1 << (num_in_bits_ - 1)) - 1);
57 in_qparams_.zero_point = 1 << (num_in_bits_ - 1);
58 in_qparams_.precision = num_in_bits_;
// Output quantization: the largest positive magnitude maps to 1 and the zero
// point sits at 2^(num_out_bits_ - 1).
61 out_qparams_.scale = 1. / ((1 << (num_out_bits_ - 1)) - 1);
62 out_qparams_.zero_point = 1 << (num_out_bits_ - 1);
63 out_qparams_.precision = num_out_bits_;
// NOTE(review): the result of this call is not visibly stored on this line;
// presumably it is assigned to x_pq_index_, which is used below - confirm.
68 GetPassRegionEnd_(in_qparams_, out_qparams_, max_abs_err, num_in_bits_);
// One table entry per magnitude from x_pq_index_ through in_pos_qmax + 1.
70 int in_pos_qmax = (1 << (num_in_bits_ - 1)) - 1;
71 processing_region_lut_.resize(in_pos_qmax - x_pq_index_ + 2);
74 for (i = x_pq_index_; i < in_pos_qmax; ++i) {
// tanh at the two real-valued edges of the input bin represented by i.
75 double y_begin = tanh((i - 0.5) * in_qparams_.scale);
76 double y_end = tanh((i + 0.5) * in_qparams_.scale);
// Quantized midpoint of the two edge values. The asserts check that it stays
// within max_abs_err of both edges and fits the positive output range.
78 int y_avg_q = nearbyint((y_begin + y_end) / 2 / out_qparams_.scale);
79 assert(y_avg_q * out_qparams_.scale - y_begin < max_abs_err);
80 assert(y_end - y_avg_q * out_qparams_.scale < max_abs_err);
82 assert(y_avg_q < (1 << (num_out_bits_ - 1)));
83 processing_region_lut_[i - x_pq_index_] = y_avg_q;
// Optional dump of the table entry when PRINT_TANH_TABLE is defined.
84 #ifdef PRINT_TANH_TABLE 85 LOG(INFO) << i <<
" " << y_avg_q;
// The two entries written from here on hold the largest positive output
// value, i.e. the saturated result.
// NOTE(review): presumably these run after the loop, with i == in_pos_qmax;
// the loop's closing brace is not visible - confirm.
89 processing_region_lut_[i - x_pq_index_] = (1 << (num_out_bits_ - 1)) - 1;
90 #ifdef PRINT_TANH_TABLE 91 LOG(INFO) << i <<
" " << processing_region_lut_[i - x_pq_index_];
93 processing_region_lut_[i - x_pq_index_ + 1] = (1 << (num_out_bits_ - 1)) - 1;
94 #ifdef PRINT_TANH_TABLE 95 LOG(INFO) << i + 1 <<
" " << processing_region_lut_[i - x_pq_index_ + 1];
// Sign of val as an integer: (T(0) < val) contributes 1 for a positive val and
// (val < T(0)) contributes 1 for a negative val, so the difference is +1, -1
// or 0.
// NOTE(review): the enclosing function header is not visible here; presumably
// this is the body of the sgn() helper that Compute calls - confirm.
101 return (
T(0) < val) - (val <
T(0));
104 template <
typename T>
105 T Tanh<T>::Compute(
T x)
const {
106 int32_t x_adjusted = x - in_qparams_.zero_point;
107 int32_t x_sgn = sgn(x_adjusted), x_mag = std::abs(x_adjusted);
110 if (x_mag < x_pq_index_) {
112 float scale_multiplier = in_qparams_.scale / out_qparams_.scale;
113 int log2_scale_multiplier = nearbyint(log2(scale_multiplier));
114 if (log2_scale_multiplier < 0) {
115 y = x_sgn * (x_mag >> (-log2_scale_multiplier));
117 y = x_sgn * (x_mag << log2_scale_multiplier);
121 y = x_sgn * processing_region_lut_[x_mag - x_pq_index_];
124 assert(y + out_qparams_.zero_point <= std::numeric_limits<T>::max());
127 assert(y + out_qparams_.zero_point >= 0);
128 assert(y + out_qparams_.zero_point < (1 << num_out_bits_));
130 return y + out_qparams_.zero_point;
133 template class Tanh<uint8_t>;
134 template class Tanh<uint16_t>;
135 template class Tanh<int32_t>;