Caffe2 - C++ API
A deep learning, cross-platform ML framework
caffe2_dnnlowp_utils.h
1 #pragma once
2 
3 #include "caffe2/core/operator.h"
4 #include "caffe2/quantization/server/dnnlowp.h"
5 #include "caffe2/utils/eigen_utils.h"
6 
7 namespace dnnlowp {
8 
13 void PropagateOutputTensorQuantizationParams(
15  int output_index,
16  const TensorQuantizationParams& qparams);
17 
25 TensorQuantizationParams GetInputTensorQuantizationParamsOf(
27  int input_index,
28  const QuantizationFactory* qfactory,
29  bool is_weight = false);
30 
31 void SetStaticQuantizationParams(
33  int output_index,
34  const TensorQuantizationParams& qparams);
35 
// Declares HasStaticQuantization: takes the operator through a
// pointer-to-const (so op cannot be modified through this parameter) and an
// output index that defaults to 0, and returns a bool.
// NOTE(review): judging only by the name, this presumably reports whether
// static (pre-set) quantization parameters are available for that output of
// op -- confirm against the implementation.
40 bool HasStaticQuantization(
41  const caffe2::OperatorBase* op,
42  int output_index = 0);
43 
// Declares GetStaticQuantizationParamsOf: returns a TensorQuantizationParams
// for output `output_index` of op, which is taken through a pointer-to-const.
// Unlike HasStaticQuantization, output_index has no default value here.
// NOTE(review): presumably only meaningful when
// HasStaticQuantization(op, output_index) is true -- confirm what the
// implementation does otherwise.
48 TensorQuantizationParams GetStaticQuantizationParamsOf(
49  const caffe2::OperatorBase* op,
50  int output_index);
51 
58 template <typename T>
59 const T* QuantizeInputIfNeeded(
61  int input_index,
62  const TensorQuantizationParams& qparams,
63  std::vector<T>& temp);
64 
65 template <typename T>
66 const T* RowWiseQuantizeInputIfNeeded(
68  int input_index,
69  const std::vector<TensorQuantizationParams>& qparams,
70  std::vector<T>& temp);
71 
73  float sum_sq{0}, sum_err_sq{0};
74  float max_abs_err{0};
75  // actual and reference values that resulted in max_abs_err
76  float max_err_actual{0}, max_err_ref{0};
77  int measure_cnt{0};
78 };
79 
80 void MeasureQuantizationError(
81  const float* actual,
82  const float* ref,
83  size_t len,
85 
// Declares ReportQuantizationError: takes op through a pointer-to-const and a
// QuantizationErrorStats by const reference, and returns nothing. Neither
// argument can be modified through these parameters.
// NOTE(review): presumably outputs the error statistics gathered for op
// (e.g. by MeasureQuantizationError). Where the report goes (log, stdout, ...)
// is not visible here -- confirm in the implementation.
86 void ReportQuantizationError(
87  const caffe2::OperatorBase* op,
88  const QuantizationErrorStats& stat);
89 
// Declares GetQuantizationFactoryOf: takes op through a pointer-to-const and
// returns a std::unique_ptr<QuantizationFactory>. The unique_ptr return type
// means ownership of the returned factory passes to the caller.
// NOTE(review): presumably the factory is configured from op's arguments --
// confirm in the implementation.
93 std::unique_ptr<QuantizationFactory> GetQuantizationFactoryOf(
94  const caffe2::OperatorBase* op);
95 
96 void AdjustOutputTensorQuantizationParamsWithFollowedBy(
98  const std::string& followed_by);
99 
100 void ParseDNNLowPOperatorArguments(
102  bool* dequantize_output = nullptr,
103  bool* measure_quantization_error = nullptr,
104  std::string* followed_by = nullptr);
105 
// Declares AddScaleZeroOffsetArgumentsWithHistogram: takes a caffe2::NetDef
// by value and a histogram file name, and returns a caffe2::NetDef.
// Because net_def is passed by value, the caller's NetDef is not modified
// through this parameter; any changes come back only in the return value.
// NOTE(review): judging by the name, this presumably reads histograms from
// histogram_file_name and adds scale / zero-offset arguments to the net's
// operators -- confirm the file format and which operators are affected.
106 caffe2::NetDef AddScaleZeroOffsetArgumentsWithHistogram(
107  caffe2::NetDef net_def,
108  const std::string& histogram_file_name);
109 
110 } // namespace dnnlowp