Caffe2 - C++ API
A cross-platform deep learning (ML) framework
elementwise_mul_dnnlowp_op.cc
1 #include "caffe2/operators/elementwise_mul_op.h"
2 #include "caffe2/quantization/server/elementwise_dnnlowp_op.h"
3 #include "caffe2/quantization/server/op_wrapper.h"
4 #include "caffe2/quantization/server/sigmoid.h"
5 
6 namespace caffe2 {
7 
8 using namespace std;
9 using namespace dnnlowp;
10 
11 using MulFp32Op =
12  BinaryElementwiseOp<NumericTypes, CPUContext, MulFunctor<CPUContext>>;
13 
14 template <typename T>
15 class MulDNNLowPOp : public BinaryElementwiseDNNLowPOp<T, MulFp32Op> {
16  public:
17  USE_OPERATOR_FUNCTIONS(CPUContext);
18  USE_DNNLOWP_OPERATOR_BASE_FUNCTIONS(T, MulFp32Op);
22 
23  MulDNNLowPOp(const OperatorDef& operator_def, Workspace* ws)
24  : BinaryElementwiseDNNLowPOp<T, MulFp32Op>(operator_def, ws) {}
25 
26  bool RunOnDevice() override {
27  if (!GetQuantizationParameters_()) {
28  return false;
29  }
30 
31  const auto& A = InputTensorCPU_(0);
32  const auto& B = InputTensorCPU_(1);
33  auto* C = OutputTensorCPU_(0);
34  CAFFE_ENFORCE(
35  &B != C || !enable_broadcast_,
36  "In-place is allowed only with the first tensor when broadcasting");
37  C->ResizeLike(A);
38 
39  // Quantize inputs if needed
40  vector<T> A_temp, B_temp;
41  const T* A_quantized =
42  QuantizeInputIfNeeded<T>(this, 0, in_qparams_[0], A_temp);
43  const T* B_quantized =
44  QuantizeInputIfNeeded<T>(this, 1, in_qparams_[1], B_temp);
45 
46  T* C_quantized = GetQuantizedOutputData_();
47 
48  if (!enable_broadcast_) {
49  CAFFE_ENFORCE_EQ(
50  A.sizes(),
51  B.sizes(),
52  "Dimension mismatch - did you forget to set broadcast=1?");
53 #ifdef _OPENMP
54 #pragma omp parallel for
55 #endif
56  for (int i = 0; i < C->size(); ++i) {
57  int32_t raw = (A_quantized[i] - in_qparams_[0].zero_point) *
58  (B_quantized[i] - in_qparams_[1].zero_point);
59  C_quantized[i] = fbgemm::Requantize<T>(raw, requantization_params_);
60  }
61  } else if (B.size() == 1) {
62 #ifdef _OPENMP
63 #pragma omp parallel for
64 #endif
65  for (int i = 0; i < C->size(); ++i) {
66  int32_t raw = (A_quantized[i] - in_qparams_[0].zero_point) *
67  (B_quantized[0] - in_qparams_[1].zero_point);
68  C_quantized[i] = fbgemm::Requantize<T>(raw, requantization_params_);
69  }
70  } else {
71  size_t pre, n, post;
72  std::tie(pre, n, post) =
73  elementwise_ops_utils::ComputeLegacyBroadcastSizes(A, B, axis_);
74 #ifdef _OPENMP
75 #pragma omp parallel for
76 #endif
77  for (int i = 0; i < pre; ++i) {
78  for (int j = 0; j < n; ++j) {
79  for (int k = 0; k < post; ++k) {
80  int32_t raw = (A_quantized[((i * n) + j) * post + k] -
81  in_qparams_[0].zero_point) *
82  (B_quantized[j] - in_qparams_[1].zero_point);
83  C_quantized[((i * n) + j) * post + k] =
84  fbgemm::Requantize<T>(raw, requantization_params_);
85  }
86  }
87  }
88  }
89 
90  RunOnDeviceEpilogue_();
91 
92  return true;
93  }
94 
95  private:
96  bool GetQuantizationParameters_() {
97  // Choose quantization for A and B
98  in_qparams_[0] =
99  GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());
100  in_qparams_[1] =
101  GetInputTensorQuantizationParamsOf(this, 1, qfactory_.get());
102 
103  GetOutputQuantizationParams_();
104 
105  float real_multiplier =
106  in_qparams_[0].scale * in_qparams_[1].scale / out_qparams_.scale;
107  requantization_params_ = qfactory_->ChooseRequantizationMultiplier(
108  real_multiplier, out_qparams_);
109 
110  return true;
111  }
112 }; // class MulDNNLowPOp
113 
114 REGISTER_CPU_OPERATOR_WITH_ENGINE(Mul, DNNLOWP, MulDNNLowPOp<uint8_t>);
115 REGISTER_CPU_OPERATOR_WITH_ENGINE(Int8Mul, DNNLOWP, MulDNNLowPOp<uint8_t>);
116 
117 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Does bound shape inference given a Caffe2 (C2) net.
Definition: static.cpp:64
Definition: static.cpp:58