3 #include "caffe2/core/operator.h" 4 #include "caffe2/core/tensor_int8.h" 5 #include "caffe2/quantization/server/caffe2_dnnlowp_utils.h" 6 #include "caffe2/quantization/server/dnnlowp.h" 14 template <
typename OpType,
typename T>
18 : op_(op), qfactory_(qfactory) {
19 for (
auto name : op->debug_def().input()) {
20 local_input_blobs_.push_back(local_ws_.
CreateBlob(name));
21 CHECK_NOTNULL(local_input_blobs_.back());
23 OperatorDef def = op->debug_def();
24 local_op_.reset(
new OpType(def, &local_ws_));
25 for (
auto name : def.output()) {
26 local_output_blobs_.push_back(local_ws_.
GetBlob(name));
27 CHECK_NOTNULL(local_output_blobs_.back());
31 void DequantizeInput() {
32 const OperatorDef& def = op_->debug_def();
35 for (
int i = 0; i < op_->InputSize(); ++i) {
39 BlobGetMutableTensor(local_input_blobs_[i], CPU);
43 float_tensor->ResizeLike(qtensor);
44 fbgemm::Dequantize<T>(
46 float_tensor->template mutable_data<float>(),
48 dnnlowp::GetInputTensorQuantizationParamsOf(op_, i, qfactory_));
50 local_input_blobs_[i]->ShareExternal(
51 const_cast<void*>(op_->Inputs()[i]->GetRaw()),
52 op_->Inputs()[i]->meta());
58 return local_op_.get();
61 dnnlowp::TensorQuantizationParams GetOutputQuantizationParams(
67 auto& out_tensor = local_output_blobs_[index]->template Get<TensorCPU>();
69 out_tensor.template data<float>(), &min, &max, out_tensor.numel());
70 if (op_->OperatorBase::GetSingleArgument<std::string>(
"followed_by",
"") ==
72 min = std::max(0.0f, min);
73 max = std::max(0.0f, max);
82 std::vector<Blob*> local_input_blobs_;
83 std::vector<Blob*> local_output_blobs_;
84 std::unique_ptr<OpType> local_op_;
Blob * CreateBlob(const string &name)
Creates a blob of the given name.
TensorQuantizationParams ChooseQuantizationParams(float min, float max, int precision, bool preserve_sparsity, bool is_signed=false) const
Choose the quantization scale and zero_point that map the floating-point range [min, max] to the integer range of the requested precision.
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement.
Workspace is a class that holds all the related objects created during runtime: (1) all blobs, and (2) all instantiated networks.
const Blob * GetBlob(const string &name) const
Gets the blob with the given name as a const pointer.
A global dictionary that holds information about which Caffe2 modules have been loaded in the current runtime.
Wrap a floating-point operator whose inputs are quantized with type T, dequantizing those inputs into a local workspace for the wrapped operator.