Caffe2 - C++ API
A deep learning, cross platform ML framework
utility_dnnlowp_ops.cc
1 #include "utility_dnnlowp_ops.h"
2 
3 namespace caffe2 {
4 
5 template <typename T>
6 GatherDNNLowPOp<T>::GatherDNNLowPOp(
7  const OperatorDef& operator_def,
8  Workspace* ws)
9  : GatherOp<CPUContext>(operator_def, ws),
10  qfactory_(dnnlowp::GetQuantizationFactoryOf(this)) {}
11 
12 template <typename T>
13 GatherDNNLowPOp<T>::~GatherDNNLowPOp() {
14  if (measure_quantization_error_) {
15  dnnlowp::ReportQuantizationError(this, quantization_error_stats_);
16  }
17 }
18 
19 template <typename T>
20 bool GatherDNNLowPOp<T>::RunOnDevice() {
21  using namespace dnnlowp;
22 
23  if (!arguments_parsed_) {
24  dnnlowp::ParseDNNLowPOperatorArguments(
25  this, &dequantize_output_, &measure_quantization_error_);
26  arguments_parsed_ = true;
27  }
28 
29  if (!InputIsType<int8::Int8TensorCPU>(DATA)) {
30  if (dequantize_output_) {
31  return GatherOp<CPUContext>::RunOnDevice();
32  } else {
33  // If input or output is float, delegate to fp32 op
34  Fp32Op_()->DequantizeInput();
35  // dequantize input if it's not already float
36  if (!Fp32Op_()->Get()->RunOnDevice()) {
37  return false;
38  }
39 
40  int8::Int8TensorCPU* output =
41  Outputs()[0]->template GetMutable<int8::Int8TensorCPU>();
42 
43  output->t.ResizeLike(*Fp32Op_()->Get()->Output(0));
44  T* out_data = output->t.template mutable_data<T>();
45 
46  TensorQuantizationParams out_qparams;
47  if (HasStaticQuantization(this)) {
48  out_qparams = GetStaticQuantizationParamsOf(this, 0);
49  } else {
50  out_qparams = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get());
51  }
52 
53  fbgemm::Quantize<T>(
54  static_cast<const float*>(Fp32Op_()->Get()->Output(0)->raw_data()),
55  out_data,
56  output->t.numel(),
57  out_qparams);
58 
59  PropagateOutputTensorQuantizationParams(this, 0, out_qparams);
60  }
61  } else {
62  DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(INDICES));
63 
64  TensorQuantizationParams in_qparams =
65  GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());
66 
67  PropagateOutputTensorQuantizationParams(this, 0, in_qparams);
68  }
69 
70  return true;
71 }
72 
73 REGISTER_CPU_OPERATOR_WITH_ENGINE(Gather, DNNLOWP, GatherDNNLowPOp<uint8_t>);
74 REGISTER_CPU_OPERATOR_WITH_ENGINE(
75  Int8Gather,
76  DNNLOWP,
77  GatherDNNLowPOp<uint8_t>);
78 
79 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13