1 #include "utility_dnnlowp_ops.h" 6 GatherDNNLowPOp<T>::GatherDNNLowPOp(
7 const OperatorDef& operator_def,
9 : GatherOp<CPUContext>(operator_def, ws),
10 qfactory_(
dnnlowp::GetQuantizationFactoryOf(this)) {}
13 GatherDNNLowPOp<T>::~GatherDNNLowPOp() {
14 if (measure_quantization_error_) {
15 dnnlowp::ReportQuantizationError(
this, quantization_error_stats_);
20 bool GatherDNNLowPOp<T>::RunOnDevice() {
23 if (!arguments_parsed_) {
24 dnnlowp::ParseDNNLowPOperatorArguments(
25 this, &dequantize_output_, &measure_quantization_error_);
26 arguments_parsed_ =
true;
29 if (!InputIsType<int8::Int8TensorCPU>(DATA)) {
30 if (dequantize_output_) {
31 return GatherOp<CPUContext>::RunOnDevice();
34 Fp32Op_()->DequantizeInput();
36 if (!Fp32Op_()->Get()->RunOnDevice()) {
40 int8::Int8TensorCPU* output =
41 Outputs()[0]->template GetMutable<int8::Int8TensorCPU>();
43 output->t.ResizeLike(*Fp32Op_()->Get()->Output(0));
44 T* out_data = output->t.template mutable_data<T>();
46 TensorQuantizationParams out_qparams;
47 if (HasStaticQuantization(
this)) {
48 out_qparams = GetStaticQuantizationParamsOf(
this, 0);
50 out_qparams = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get());
54 static_cast<const float*
>(Fp32Op_()->Get()->Output(0)->raw_data()),
59 PropagateOutputTensorQuantizationParams(
this, 0, out_qparams);
62 DispatchHelper<TensorTypes<int32_t, int64_t>>::call(
this, Input(INDICES));
64 TensorQuantizationParams in_qparams =
65 GetInputTensorQuantizationParamsOf(
this, 0, qfactory_.get());
67 PropagateOutputTensorQuantizationParams(
this, 0, in_qparams);
73 REGISTER_CPU_OPERATOR_WITH_ENGINE(Gather, DNNLOWP, GatherDNNLowPOp<uint8_t>);
74 REGISTER_CPU_OPERATOR_WITH_ENGINE(
77 GatherDNNLowPOp<uint8_t>);
// NOTE(review): stray prose fragment (appears to be extraction residue from an
// unrelated Caffe2 module docstring): "A global dictionary that holds
// information about what Caffe2 modules have been loaded in the current
// process." Kept as a comment so the translation unit stays well-formed.