Caffe2 - C++ API
A deep learning, cross platform ML framework
int8_relu_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_RELU_OP_H_
2 #define CAFFE2_OPERATORS_INT8_RELU_OP_H_
3 
4 #include <qnnpack.h>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/core/tensor_int8.h"
9 #include "caffe2/operators/quantized/int8_utils.h"
10 
11 namespace caffe2 {
12 
13 namespace int8 {
14 
15 class Int8ReluOp final : public Operator<CPUContext> {
16  public:
17  explicit Int8ReluOp(const OperatorDef& operator_def, Workspace* ws)
18  : Operator<CPUContext>(operator_def, ws), ws_(ws) {}
19 
20  ~Int8ReluOp() {
21  if (this->qnnpackOperator_ != nullptr) {
22  qnnp_delete_operator(this->qnnpackOperator_);
23  this->qnnpackOperator_ = nullptr;
24  }
25  }
26 
27  bool RunOnDevice() override {
28  const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
29  auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
30  Y->t.ResizeLike(X.t);
31  Y->scale = X.scale;
32  Y->zero_point = X.zero_point;
33  CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
34  CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
35  const int32_t Y_offset =
36  this->template GetSingleArgument<int>("Y_zero_point", 0);
37  const float Y_scale =
38  this->template GetSingleArgument<float>("Y_scale", 1.0f);
39  CHECK_EQ(Y_offset, X.zero_point);
40  CHECK_EQ(Y_scale, X.scale);
41 
42  initQNNPACK();
43 
44  if (this->qnnpackOperator_ == nullptr) {
45  const qnnp_status createStatus = qnnp_create_clamp_nc_u8(
46  1 /* channels */,
47  X.zero_point /* output min */,
48  255 /* output max */,
49  0 /* flags */,
50  &qnnpackOperator_);
51  CAFFE_ENFORCE(
52  createStatus == qnnp_status_success,
53  "failed to create QNNPACK Clamp operator");
54  CAFFE_ENFORCE(this->qnnpackOperator_ != nullptr);
55  }
56 
57  const qnnp_status setupStatus = qnnp_setup_clamp_nc_u8(
58  this->qnnpackOperator_,
59  X.t.numel() /* batch size */,
60  X.t.template data<uint8_t>(),
61  1 /* X stride */,
62  Y->t.template mutable_data<uint8_t>(),
63  1 /* Y stride */);
64  CAFFE_ENFORCE(
65  setupStatus == qnnp_status_success,
66  "failed to setup QNNPACK Clamp operator");
67 
68 #ifdef FBCODE_CAFFE2
69  const qnnp_status runStatus =
70  qnnp_run_operator(this->qnnpackOperator_, nullptr /* thread pool */);
71 #else
72  pthreadpool_t threadpool =
73  reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
74  const qnnp_status runStatus =
75  qnnp_run_operator(this->qnnpackOperator_, threadpool);
76 #endif
77  CAFFE_ENFORCE(
78  runStatus == qnnp_status_success,
79  "failed to run QNNPACK Clamp operator");
80 
81  return true;
82  }
83 
84  private:
85  Workspace* ws_;
86  // QNNPACK Clamp operator
87  qnnp_operator_t qnnpackOperator_{nullptr};
88 };
89 
90 } // namespace int8
91 
92 } // namespace caffe2
93 
94 #endif // CAFFE2_OPERATORS_INT8_RELU_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13