Caffe2 - C++ API
A deep learning, cross platform ML framework
lstm_unit_dnnlowp_op.h
1 #pragma once
2 
3 #include "caffe2/operators/lstm_unit_op.h"
4 #include "caffe2/quantization/server/caffe2_dnnlowp_utils.h"
5 #include "caffe2/quantization/server/dnnlowp.h"
6 #include "caffe2/quantization/server/op_wrapper.h"
7 #include "caffe2/quantization/server/sigmoid.h"
8 
9 namespace caffe2 {
10 
// LSTM unit operator variant from the DNNLOWP quantization server code,
// layered on top of the regular CPU operator LSTMUnitOp<CPUContext>.
//
// T is the element type used by the quantized helpers below (Sigmoid<T>,
// Tanh<T>, OpWrapper<..., T>); it must be an integral type, which the
// static_assert enforces at compile time.
//
// Only declarations appear here; the member functions are defined elsewhere.
// NOTE(review): the listing's own numbering skips 17 and 24, so one or more
// declarations may be missing from this view -- check the real header.
11 template <typename T>
12 class LSTMUnitDNNLowPOp final : public LSTMUnitOp<CPUContext> {
13  static_assert(std::is_integral<T>::value, "Integral required.");
14 
15  public:
    // Builds the operator from its definition and a workspace pointer.
16  LSTMUnitDNNLowPOp(const OperatorDef& operator_def, Workspace* ws);
    // Overrides the base-class entry point. The bool result presumably
    // signals success -- confirm against the operator framework.
18  bool RunOnDevice() override;
19 
20  private:
    // Access the idx-th input / output as a CPU tensor. Whether these
    // dequantize or simply forward to the base class is not visible here.
21  const TensorCPU& InputTensorCPU_(int idx);
22  TensorCPU* OutputTensorCPU_(int idx);
    // Presumably populates the *_qparams_ members below -- confirm in the
    // implementation file.
23  bool GetQuantizationParameters_();
25 
    // NOTE(review): no default member initializer here, unlike the other
    // bool members -- verify that the constructor always sets it.
26  bool drop_states_;
    // Quantized activation helpers. Per sigmoid.h, the quantized sigmoid is
    // computed as tanh under the hood with different input/output
    // quantization parameters; per tanh.h, Tanh uses a 3-region approach.
27  dnnlowp::Sigmoid<T> sigmoid_;
28  dnnlowp::Tanh<T> tanh_;
29 
    // Quantization parameters for the operator's tensors. The names suggest
    // hidden state (H), cell state (C) and gates (G) -- TODO confirm.
30  dnnlowp::TensorQuantizationParams H_in_qparams_, C_in_qparams_, G_in_qparams_,
31  H_out_qparams_, C_out_qparams_;
32 
    // Wrapper around the floating-point LSTMUnitOp that accepts quantized
    // inputs of type T (see op_wrapper.h). When it is created is not shown.
33  std::unique_ptr<OpWrapper<LSTMUnitOp<CPUContext>, T>> fp32_op_;
    // Both flags default to false; where they are set is not visible here.
34  bool dequantize_output_{false}, measure_quantization_error_{false};
35 
    // Owned quantization factory; presumably used to choose the quantization
    // parameters above -- confirm.
36  std::unique_ptr<dnnlowp::QuantizationFactory> qfactory_;
37 
    // Error statistics for the cell and hidden outputs; presumably updated
    // only when measure_quantization_error_ is true -- confirm.
38  dnnlowp::QuantizationErrorStats cell_quantization_error_stats_,
39  hidden_quantization_error_stats_;
40 
    // Defaults to false; presumably guards one-time argument parsing --
    // confirm in the implementation file.
41  bool arguments_parsed_{false};
42 }; // class LSTMUnitDNNLowPOp
43 
44 } // namespace caffe2
sigmoid(x) = (tanh(x/2) + 1)/2. Quantized sigmoid is computed as tanh under the hood; we just use different input/output quantization parameters.
Definition: sigmoid.h:13
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
We use the 3-region approach described in "Efficient VLSI Implementation of Neural Networks with Hype...
Definition: tanh.h:21
Wrap a floating-point operator with quantized inputs with type T.
Definition: op_wrapper.h:15