// NOTE(review): this extract is garbled — the embedded original line numbers
// jump (2, 16, 18, 21-22 are missing), so the text below is incomplete and
// not compilable as-is. Comments only; code bytes left untouched.
//
// Constructor: forwards operator_def/ws to the fp32 LSTMUnitOp<CPUContext>
// base, reads the boolean "drop_states" operator argument (default false),
// and constructs the quantization factory used throughout this operator.
1 #include "lstm_unit_dnnlowp_op.h" 3 #include "caffe2/core/tensor_int8.h" 4 #include "caffe2/quantization/server/dnnlowp.h" 5 #include "caffe2/quantization/server/sigmoid.h" 6 #include "caffe2/quantization/server/tanh.h" 14 LSTMUnitDNNLowPOp<T>::LSTMUnitDNNLowPOp(
15 const OperatorDef& operator_def,
// Missing line 16 presumably declares the Workspace* ws parameter — confirm
// against the original file.
17 : LSTMUnitOp<CPUContext>(operator_def, ws),
// Missing line 18 presumably initializes the member (e.g. drop_states_) that
// receives the GetSingleArgument result below — confirm against the original.
19 this->template GetSingleArgument<bool>(
"drop_states", false)),
20 qfactory_(GetQuantizationFactoryOf(this)) {}
// Destructor. When quantization-error measurement was enabled for this op,
// report the accumulated error statistics for both the cell-state and
// hidden-state outputs before the op is torn down.
// NOTE(review): the closing braces (original lines ~27-28) are missing from
// this extract; code bytes left untouched.
23 LSTMUnitDNNLowPOp<T>::~LSTMUnitDNNLowPOp() {
24 if (measure_quantization_error_) {
25 ReportQuantizationError(
this, cell_quantization_error_stats_);
26 ReportQuantizationError(
this, hidden_quantization_error_stats_);
// Lazily constructs (on first call) and returns the wrapper around the fp32
// reference LSTMUnitOp. The wrapper is used to run the float implementation
// for dynamic output-quantization-parameter selection and for measuring
// quantization error against the fp32 result.
// NOTE(review): the guard/reset lines (original 32-33, 35) are missing from
// this extract — presumably `if (!fp32_op_) { fp32_op_.reset(...` — confirm.
31 OpWrapper<LSTMUnitOp<CPUContext>,
T>* LSTMUnitDNNLowPOp<T>::Fp32Op_() {
34 new OpWrapper<LSTMUnitOp<CPUContext>,
T>(
this, qfactory_.get()));
36 return fp32_op_.get();
// Returns the underlying plain TensorCPU for input `idx`: when the input is
// an Int8TensorCPU (already-quantized input), unwrap and return its `.t`
// member; otherwise (else-branch, original lines 43-45, missing from this
// extract) presumably return Input(idx) directly — confirm.
40 const TensorCPU& LSTMUnitDNNLowPOp<T>::InputTensorCPU_(
int idx) {
41 return InputIsType<int8::Int8TensorCPU>(idx)
42 ? this->
template Input<int8::Int8TensorCPU>(idx).t
// Returns the underlying mutable TensorCPU for output `idx`. When
// dequantize_output_ is set the op emits float tensors (that branch's body,
// original lines 49-50, is missing from this extract); otherwise the output
// is an Int8TensorCPU and we hand back its wrapped `.t` tensor.
47 TensorCPU* LSTMUnitDNNLowPOp<T>::OutputTensorCPU_(
int idx) {
48 if (dequantize_output_) {
51 return &Outputs()[idx]->template GetMutable<int8::Int8TensorCPU>()->t;
// Element-wise quantized LSTM-unit kernel. The leading parameters (original
// lines ~56-62: N, D, t, the H_prev/C_prev/X/C/H pointers, the sigmoid/tanh
// lookup objects and qfactory) are missing from this extract — the visible
// tail of the signature starts mid-parameter-list below. Everything is done
// in the quantized domain: values are shifted by their zero points,
// requantized between the various quantization domains, and the nonlinearity
// objects compute quantized sigmoid/tanh.
63 const int32_t* seqLengths,
67 const int32_t forget_bias,
70 const TensorQuantizationParams& X_qparams,
71 const TensorQuantizationParams& C_in_qparams,
72 const TensorQuantizationParams& C_out_qparams,
73 const TensorQuantizationParams& H_in_qparams,
74 const TensorQuantizationParams& H_out_qparams,
// Cache the input/output quantization parameters of the quantized sigmoid
// and tanh nonlinearities; all the requantization multipliers below convert
// between these domains.
76 const TensorQuantizationParams sigmoid_in_qparams =
77 sigmoid.GetInputQuantizationParams();
78 const TensorQuantizationParams sigmoid_out_qparams =
79 sigmoid.GetOutputQuantizationParams();
80 const TensorQuantizationParams tanh_in_qparams =
81 tanh.GetInputQuantizationParams();
82 const TensorQuantizationParams tanh_out_qparams =
83 tanh.GetOutputQuantizationParams();
// Requantization for passing previous hidden/cell state straight through to
// the output domain (used for the invalid-timestep copy path below).
85 RequantizationParams h_in_to_out_params =
86 qfactory->ChooseRequantizationMultiplier(
87 H_in_qparams.scale / H_out_qparams.scale, H_out_qparams);
89 RequantizationParams c_in_to_out_params =
90 qfactory->ChooseRequantizationMultiplier(
91 C_in_qparams.scale / C_out_qparams.scale, C_out_qparams);
93 float sigmoid_scale = sigmoid_out_qparams.scale;
94 float tanh_scale = tanh_out_qparams.scale;
95 int32_t sigmoid_zero_point = sigmoid_out_qparams.zero_point;
96 int32_t tanh_zero_point = tanh_out_qparams.zero_point;
// Gate pre-activations X -> sigmoid/tanh input domains.
98 RequantizationParams x_to_sigmoid_params =
99 qfactory->ChooseRequantizationMultiplier(
100 X_qparams.scale / sigmoid_in_qparams.scale, sigmoid_in_qparams);
102 RequantizationParams x_to_tanh_params =
103 qfactory->ChooseRequantizationMultiplier(
104 X_qparams.scale / tanh_in_qparams.scale, tanh_in_qparams);
// Rescale f*C_prev products so they share the (sigmoid_scale * tanh_scale)
// domain of i*g before the two terms are summed.
106 RequantizationParams c_to_tanh_params =
107 qfactory->ChooseRequantizationMultiplier(
108 C_in_qparams.scale / tanh_scale, tanh_out_qparams);
// The summed 32-bit cell accumulator (scale sigmoid_scale*tanh_scale) is
// requantized three ways: to the cell output, to the tanh input (for the
// hidden path), and — after o * tanh(c) — to the hidden output.
110 RequantizationParams c_out_requantization_params =
111 qfactory->ChooseRequantizationMultiplier(
112 sigmoid_scale * tanh_scale / C_out_qparams.scale, C_out_qparams);
114 RequantizationParams c_tanh_requantization_params =
115 qfactory->ChooseRequantizationMultiplier(
116 sigmoid_scale * tanh_scale / tanh_in_qparams.scale, tanh_in_qparams);
118 RequantizationParams h_requantization_params =
119 qfactory->ChooseRequantizationMultiplier(
120 sigmoid_scale * tanh_scale / H_out_qparams.scale, H_out_qparams);
// Per-sample loop; a sample is "valid" only while the current timestep t is
// within its sequence length.
122 for (
int n = 0; n < N; ++n) {
123 const bool valid = t < seqLengths[n];
125 for (
int d = 0; d < D; ++d) {
// Invalid timestep: either zero the states (quantized zero == zero_point)
// or copy the previous states through, requantized to the output domains.
// NOTE(review): the branch structure selecting between these two paths
// (original lines 126-127, 130, 135-136, presumably involving drop_states)
// is missing from this extract — confirm against the original file.
128 H[d] = H_out_qparams.zero_point;
129 C[d] = C_out_qparams.zero_point;
131 H[d] = fbgemm::Requantize<T>(
132 H_prev[d] - H_in_qparams.zero_point, h_in_to_out_params);
133 C[d] = fbgemm::Requantize<T>(
134 C_prev[d] - C_in_qparams.zero_point, c_in_to_out_params);
// Valid timestep: gates are laid out in X as [i | f | o | g], each of
// width D. The forget gate additionally adds the quantized forget bias
// (hence the 2 * zero_point correction for the two X-domain addends).
137 T i_in = fbgemm::Requantize<T>(
138 X[d] - X_qparams.zero_point, x_to_sigmoid_params);
139 T f_in = fbgemm::Requantize<T>(
140 X[1 * D + d] + forget_bias - 2 * X_qparams.zero_point,
141 x_to_sigmoid_params);
142 T o_in = fbgemm::Requantize<T>(
143 X[2 * D + d] - X_qparams.zero_point, x_to_sigmoid_params);
144 T g_in = fbgemm::Requantize<T>(
145 X[3 * D + d] - X_qparams.zero_point, x_to_tanh_params);
// Quantized nonlinearities (sigmoid is implemented via tanh internally).
147 const T i = sigmoid.Compute(i_in);
148 const T f = sigmoid.Compute(f_in);
149 const T o = sigmoid.Compute(o_in);
150 const T g = tanh.Compute(g_in);
151 const T c_prev = C_prev[d];
// c = f * c_prev + i * g, accumulated in int32 with zero points removed.
154 int32_t f_times_c_prev = ((int32_t)f - sigmoid_zero_point) *
155 ((int32_t)c_prev - C_in_qparams.zero_point);
// NOTE(review): the lhs of this product (original lines 156-158,
// presumably `int32_t i_times_g =`) is missing from this extract.
159 ((int32_t)i - sigmoid_zero_point) * ((int32_t)g - tanh_zero_point);
// Bring f*c_prev into the same scale as i*g before summing; the full
// Requantize argument list (original lines 163-167) is partly missing.
162 int32_t f_times_c_prev_rescaled = fbgemm::Requantize<int32_t>(
165 c_to_tanh_params.real_multiplier,
168 int32_t c_temp = f_times_c_prev_rescaled + i_times_g;
171 C[d] = fbgemm::Requantize<T>(c_temp, c_out_requantization_params);
// h = o * tanh(c): requantize the cell accumulator into the tanh input
// domain, apply tanh, multiply by the output gate, and requantize to the
// hidden output domain. (Assignment targets on original lines 173 and 180
// — presumably c_tanh_input and H[d] — are missing from this extract.)
174 fbgemm::Requantize<T>(c_temp, c_tanh_requantization_params);
175 T host_tanh_c = tanh.Compute(c_tanh_input);
178 int32_t o_times_host_tanh_c = ((int32_t)o - sigmoid_zero_point) *
179 ((int32_t)host_tanh_c - tanh_zero_point);
181 fbgemm::Requantize<T>(o_times_host_tanh_c, h_requantization_params);
// Chooses all quantization parameters needed for one run:
//  - input qparams for the previous hidden and cell states (assignment
//    targets on the missing original lines 196/198 — presumably
//    H_in_qparams_ / C_in_qparams_);
//  - gate-input qparams spanning the union of the sigmoid and tanh input
//    ranges (the std::min/std::max wrappers on original lines 203/206 are
//    missing from this extract);
//  - output qparams: taken from static (user-provided) quantization when
//    present, otherwise derived dynamically by running the fp32 reference
//    op and observing its outputs.
192 template <
typename T>
193 bool LSTMUnitDNNLowPOp<T>::GetQuantizationParameters_() {
197 GetInputTensorQuantizationParamsOf(
this, HIDDEN_T_M_1, qfactory_.get());
199 GetInputTensorQuantizationParamsOf(
this, CELL_T_M_1, qfactory_.get());
202 G_in_qparams_ = qfactory_->ChooseQuantizationParams(
204 sigmoid_.GetInputQuantizationParams().Min(),
205 tanh_.GetInputQuantizationParams().Min()),
207 sigmoid_.GetInputQuantizationParams().Max(),
208 tanh_.GetInputQuantizationParams().Max()));
// Static quantization path: take the user-specified output qparams.
210 if (HasStaticQuantization(
this, HIDDEN_T)) {
211 H_out_qparams_ = GetStaticQuantizationParamsOf(
this, HIDDEN_T);
213 if (HasStaticQuantization(
this, CELL_T)) {
214 C_out_qparams_ = GetStaticQuantizationParamsOf(
this, CELL_T);
// Dynamic path: if either output lacks static qparams, run the fp32
// reference op on dequantized inputs and derive qparams from its outputs.
217 if (!HasStaticQuantization(
this, HIDDEN_T) ||
218 !HasStaticQuantization(
this, CELL_T)) {
219 Fp32Op_()->DequantizeInput();
220 if (!Fp32Op_()->Get()->RunOnDevice()) {
// (Failure handling on the missing original lines 221-222 — presumably
// LOG/return false — is not visible in this extract.)
223 if (!HasStaticQuantization(
this, HIDDEN_T)) {
225 Fp32Op_()->GetOutputQuantizationParams(qfactory_.get(), HIDDEN_T);
227 if (!HasStaticQuantization(
this, CELL_T)) {
229 Fp32Op_()->GetOutputQuantizationParams(qfactory_.get(), CELL_T);
// Main entry point. Parses DNNLOWP arguments once, chooses quantization
// parameters, quantizes inputs if they arrive as float, runs the quantized
// LSTMUnit kernel, then optionally dequantizes outputs and measures
// quantization error against the fp32 reference. Many original lines are
// missing from this extract (e.g. the kernel call site around lines
// 295-316), so control flow between the visible statements is incomplete.
236 template <
typename T>
237 bool LSTMUnitDNNLowPOp<T>::RunOnDevice() {
// One-time parsing of dequantize_output / measure_quantization_error args.
238 if (!arguments_parsed_) {
239 ParseDNNLowPOperatorArguments(
240 this, &dequantize_output_, &measure_quantization_error_);
241 arguments_parsed_ =
true;
244 GetQuantizationParameters_();
// Shape extraction: N = batch, D = hidden size, G = gate width; the gates
// tensor packs the 4 LSTM gates, hence G == 4 * D.
247 const auto N = InputTensorCPU_(CELL_T_M_1).size(1);
250 const auto G = InputTensorCPU_(GATES).size(2);
251 const auto D = InputTensorCPU_(CELL_T_M_1).size(2);
253 CAFFE_ENFORCE_EQ(4 * D, G);
// Quantize the float inputs into temporaries when needed (the receiving
// pointer declarations for H_prev/C_prev on missing original lines 257/262
// are not visible in this extract).
256 vector<T> H_prev_temp;
258 QuantizeInputIfNeeded(
this, HIDDEN_T_M_1, H_in_qparams_, H_prev_temp);
261 vector<T> C_prev_temp;
263 QuantizeInputIfNeeded(
this, CELL_T_M_1, C_in_qparams_, C_prev_temp);
267 const T* X = QuantizeInputIfNeeded(
this, GATES, G_in_qparams_, X_temp);
// The timestep tensor immediately follows SEQ_LENGTHS in the input list.
270 const size_t TIMESTEP = SEQ_LENGTHS + 1;
272 CAFFE_ENFORCE_EQ(Input(SEQ_LENGTHS).size(), N);
273 const auto* seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
// TIMESTEP is read via OperatorBase to bypass this op's quantized-input
// handling — it is a plain int32 scalar tensor.
274 const auto t =
static_cast<OperatorBase*
>(
this)
275 ->Input<Tensor>(TIMESTEP, CPU)
276 .template data<int32_t>()[0];
277 OutputTensorCPU_(CELL_T)->ResizeLike(InputTensorCPU_(CELL_T_M_1));
278 OutputTensorCPU_(HIDDEN_T)->ResizeLike(InputTensorCPU_(CELL_T_M_1));
// When dequantizing, the kernel writes into scratch uint8 buffers that are
// dequantized into the float outputs afterwards; otherwise it writes the
// quantized outputs directly.
280 vector<uint8_t> Ctemp, Htemp;
281 uint8_t *Cdata, *Hdata;
282 if (dequantize_output_) {
283 Ctemp.resize(OutputTensorCPU_(CELL_T)->size());
284 Cdata = Ctemp.data();
286 Htemp.resize(OutputTensorCPU_(HIDDEN_T)->size());
287 Hdata = Htemp.data();
289 Cdata = OutputTensorCPU_(CELL_T)->template mutable_data<uint8_t>();
290 Hdata = OutputTensorCPU_(HIDDEN_T)->template mutable_data<uint8_t>();
// The forget bias is quantized into the gate-input domain so the kernel
// can add it to the raw quantized forget-gate pre-activation.
293 int32_t forget_bias_quantized =
294 fbgemm::Quantize<int32_t>(forget_bias_, G_in_qparams_);
// NOTE(review): the LSTMUnit kernel invocation (original lines ~295-316)
// is missing from this extract except for this one argument line.
307 forget_bias_quantized,
317 if (dequantize_output_) {
318 fbgemm::Dequantize<T>(
320 OutputTensorCPU_(CELL_T)->template mutable_data<float>(),
323 fbgemm::Dequantize<T>(
325 OutputTensorCPU_(HIDDEN_T)->template mutable_data<float>(),
// Compare our float outputs with the fp32 reference op's outputs and fold
// the error into the running stats reported by the destructor.
329 if (measure_quantization_error_) {
330 MeasureQuantizationError(
331 OutputTensorCPU_(CELL_T)->
template mutable_data<float>(),
332 Fp32Op_()->Get()->Output(CELL_T)->
template data<float>(),
333 OutputTensorCPU_(CELL_T)->size(),
334 &cell_quantization_error_stats_);
336 MeasureQuantizationError(
337 OutputTensorCPU_(HIDDEN_T)->
template mutable_data<float>(),
338 Fp32Op_()->Get()->Output(HIDDEN_T)->
template data<float>(),
339 OutputTensorCPU_(HIDDEN_T)->size(),
340 &hidden_quantization_error_stats_);
// Quantized-output path: attach the chosen qparams to the Int8 outputs so
// downstream quantized ops can interpret them.
343 PropagateOutputTensorQuantizationParams(
this, HIDDEN_T, H_out_qparams_);
344 PropagateOutputTensorQuantizationParams(
this, CELL_T, C_out_qparams_);
// Operator registrations: two engine-specific registrations of the uint8
// quantized LSTM unit op. The operator-name and engine arguments (original
// lines 351-352 and 355-356) are missing from this extract — presumably
// "LSTMUnit"/"Int8LSTMUnit" with the DNNLOWP engine; confirm against the
// original file.
350 REGISTER_CPU_OPERATOR_WITH_ENGINE(
353 LSTMUnitDNNLowPOp<uint8_t>);
354 REGISTER_CPU_OPERATOR_WITH_ENGINE(
357 LSTMUnitDNNLowPOp<uint8_t>);
sigmoid(x) = (tanh(x/2) + 1) / 2. The quantized sigmoid is computed as a tanh under the hood; we just use different input/output quantization parameters.
A global dictionary that holds information about which Caffe2 modules have been loaded in the current runtime.
We use the 3-region approach described in "Efficient VLSI Implementation of Neural Networks with Hyperbolic Tangent Activation Function": the tanh input range is split into a pass-through region, a processing (lookup) region, and a saturation region.