Caffe2 - C++ API
A deep learning, cross platform ML framework
relu_dnnlowp_op.cc
1 #include "caffe2/quantization/server/relu_dnnlowp_op.h"
2 
3 #include <limits>
4 
5 namespace caffe2 {
6 
7 template <typename T>
8 bool ReluDNNLowPOp<T>::RunOnDevice() {
9  auto& X = InputIsType<int8::Int8TensorCPU>(0)
10  ? (this->template Input<int8::Int8TensorCPU>(0)).t
11  : Input(0);
12 
13  TensorCPU* Y = nullptr;
14  if (InputIsType<int8::Int8TensorCPU>(0)) {
15  // The output follows the same type as input because ReLU can be inplace
16  Y = &Outputs()[0]->template GetMutable<int8::Int8TensorCPU>()->t;
17  } else {
18  Y = Output(0);
19  }
20  Y->ResizeLike(X);
21 
22  using namespace dnnlowp;
23 
24  // Choose quantization params
25  TensorQuantizationParams in_qparams =
26  GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());
27 
28  // Quantize input if needed
29  std::vector<T> X_temp, Y_temp;
30  const T* X_data = QuantizeInputIfNeeded(this, 0, in_qparams, X_temp);
31 
32  T* Y_data = nullptr;
33  if (X.template IsType<T>()) {
34  Y_data = Y->template mutable_data<T>();
35  } else {
36  Y_temp.resize(Y->numel());
37  Y_data = Y_temp.data();
38  }
39 
40  CAFFE_ENFORCE_GE(in_qparams.zero_point, std::numeric_limits<T>::lowest());
41  CAFFE_ENFORCE_LE(in_qparams.zero_point, std::numeric_limits<T>::max());
42  const int N = X.numel();
43  if (in_qparams.zero_point == std::numeric_limits<T>::lowest()) {
44  if (Y_data != X_data) {
45  std::memcpy(Y_data, X_data, N * sizeof(T));
46  }
47  } else {
48  if (GetCpuId().avx2()) {
49  internal::ReluAVX2<T>(N, in_qparams.zero_point, X_data, Y_data);
50  } else {
51 #ifdef _OPENMP
52 #pragma omp parallel for
53 #endif
54  for (int i = 0; i < N; ++i) {
55  Y_data[i] = std::max(X_data[i], static_cast<T>(in_qparams.zero_point));
56  }
57  }
58  }
59 
60  // Even if there is a pre-chosen quantization parameters for the output,
61  // it is ignored because relu output quantization should be same as the
62  // input.
63  PropagateOutputTensorQuantizationParams(this, 0, in_qparams);
64 
65  // If input was not quantized, output should be dequantized because ReLU
66  // can be inplace.
67  if (!X.template IsType<T>()) {
68  fbgemm::Dequantize<T>(
69  Y_data, Y->template mutable_data<float>(), Y->numel(), in_qparams);
70  }
71 
72  return true;
73 }
74 
// Operator registrations for this file. REGISTER_CPU_OPERATOR_WITH_ENGINE is
// defined elsewhere; each invocation names an operator, an engine, and the
// ReluDNNLowPOp instantiation to associate with that pair.
//
// "Relu" is registered under the DNNLOWP engine with 8-bit unsigned elements
// and under the DNNLOWP_16 engine with 16-bit unsigned elements.
75 REGISTER_CPU_OPERATOR_WITH_ENGINE(Relu, DNNLOWP, ReluDNNLowPOp<uint8_t>);
76 REGISTER_CPU_OPERATOR_WITH_ENGINE(Relu, DNNLOWP_16, ReluDNNLowPOp<uint16_t>);
77 
// "Int8Relu" is registered under the DNNLOWP engine with the same 8-bit
// instantiation used for "Relu".
78 REGISTER_CPU_OPERATOR_WITH_ENGINE(Int8Relu, DNNLOWP, ReluDNNLowPOp<uint8_t>);
79 
80 } // namespace caffe2
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: OpClasses.h:2