Caffe2 - C++ API
A deep learning, cross-platform ML framework
int8_average_pool_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_
2 #define CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_
3 
4 #include <qnnpack.h>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/core/tensor_int8.h"
9 #include "caffe2/operators/conv_pool_op_base.h"
10 #include "caffe2/operators/quantized/int8_utils.h"
11 
12 namespace caffe2 {
13 
14 namespace int8 {
15 
16 template <Activation Ac>
17 class Int8AveragePoolOp final : public ConvPoolOpBase<CPUContext> {
18  public:
19  template <class... Args>
20  explicit Int8AveragePoolOp(Args&&... args)
21  : ConvPoolOpBase<CPUContext>(std::forward<Args>(args)...) {
22  OPERATOR_NEEDS_FEATURE(
23  this->order_ == StorageOrder::NHWC, "Int8 only supports NHWC order.");
24  }
25 
26  ~Int8AveragePoolOp() {
27  if (this->qnnpackOperator_ != nullptr) {
28  qnnp_delete_operator(this->qnnpackOperator_);
29  this->qnnpackOperator_ = nullptr;
30  }
31  if (this->qnnpackGlobalOperator_ != nullptr) {
32  qnnp_delete_operator(this->qnnpackGlobalOperator_);
33  this->qnnpackGlobalOperator_ = nullptr;
34  }
35  }
36 
37  bool RunOnDeviceWithOrderNHWC() override {
38  const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
39  auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
40  int32_t Y_zero_point =
41  this->template GetSingleArgument<int>("Y_zero_point", 0);
42  auto Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
43  Y->scale = Y_scale;
44  Y->zero_point = Y_zero_point;
45 
46  CHECK_EQ(X.t.dim(), 4);
47  const int channels = X.t.dim32(3);
48  ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);
49 
50  initQNNPACK();
51 
52  const bool anyPadding =
53  pad_t() != 0 || pad_r() != 0 || pad_b() != 0 || pad_l() != 0;
54  const bool anyStride = stride_h() > 1 || stride_w() > 1;
55  const bool globalPooling = !anyPadding && !anyStride &&
56  (X.t.dim32(1) == kernel_h() && X.t.dim32(2) == kernel_w());
57  if (globalPooling) {
58  if (this->qnnpackGlobalOperator_ == nullptr) {
59  const qnnp_status createStatus =
60  qnnp_create_global_average_pooling_nwc_q8(
61  channels,
62  X.zero_point, X.scale,
63  Y->zero_point, Y->scale,
64  activationLimits(Y->scale, Y->zero_point, Ac).first,
65  activationLimits(Y->scale, Y->zero_point, Ac).second,
66  0 /* flags */,
67  &this->qnnpackGlobalOperator_);
68  CAFFE_ENFORCE(
69  createStatus == qnnp_status_success,
70  "failed to create QNNPACK Global Average Pooling operator");
71  CAFFE_ENFORCE(this->qnnpackGlobalOperator_ != nullptr);
72  }
73 
74  const qnnp_status setupStatus = qnnp_setup_global_average_pooling_nwc_q8(
75  this->qnnpackGlobalOperator_,
76  X.t.dim32(0), X.t.dim32(1) * X.t.dim32(2),
77  X.t.template data<uint8_t>(), channels,
78  Y->t.template mutable_data<uint8_t>(), channels);
79  CAFFE_ENFORCE(
80  setupStatus == qnnp_status_success,
81  "failed to setup QNNPACK Global Average Pooling operator");
82 
83 #ifdef FBCODE_CAFFE2
84  const qnnp_status runStatus =
85  qnnp_run_operator(this->qnnpackGlobalOperator_,
86  nullptr /* thread pool */);
87 #else
88  pthreadpool_t threadpool =
89  reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
90  const qnnp_status runStatus =
91  qnnp_run_operator(this->qnnpackGlobalOperator_, threadpool);
92 #endif
93  CAFFE_ENFORCE(
94  runStatus == qnnp_status_success,
95  "failed to run QNNPACK Global Average Pooling operator");
96  } else {
97  if (this->qnnpackOperator_ == nullptr) {
98  const qnnp_status createStatus = qnnp_create_average_pooling2d_nhwc_q8(
99  pad_t(), pad_r(), pad_b(), pad_l(),
100  kernel_h(), kernel_w(),
101  stride_h(), stride_w(),
102  channels,
103  X.zero_point, X.scale,
104  Y->zero_point, Y->scale,
105  activationLimits(Y->scale, Y->zero_point, Ac).first,
106  activationLimits(Y->scale, Y->zero_point, Ac).second,
107  0 /* flags */,
108  &this->qnnpackOperator_);
109  CAFFE_ENFORCE(
110  createStatus == qnnp_status_success,
111  "failed to create QNNPACK Average Pooling operator");
112  CAFFE_ENFORCE(this->qnnpackOperator_ != nullptr);
113  }
114 
115  const qnnp_status setupStatus = qnnp_setup_average_pooling2d_nhwc_q8(
116  this->qnnpackOperator_,
117  X.t.dim32(0), X.t.dim32(1), X.t.dim32(2),
118  X.t.template data<uint8_t>(), channels,
119  Y->t.template mutable_data<uint8_t>(), channels,
120  nullptr /* thread pool */);
121  CAFFE_ENFORCE(
122  setupStatus == qnnp_status_success,
123  "failed to setup QNNPACK Average Pooling operator");
124 
125 #ifdef FBCODE_CAFFE2
126  const qnnp_status runStatus =
127  qnnp_run_operator(this->qnnpackOperator_, nullptr /* thread pool */);
128 #else
129  pthreadpool_t threadpool =
130  reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
131  const qnnp_status runStatus =
132  qnnp_run_operator(this->qnnpackOperator_, threadpool);
133 #endif
134  CAFFE_ENFORCE(
135  runStatus == qnnp_status_success,
136  "failed to run QNNPACK Average Pooling operator");
137  }
138 
139  return true;
140  }
141  private:
142  // QNNPACK Average Pooling operator
143  qnnp_operator_t qnnpackOperator_{nullptr};
144  // QNNPACK Global Average Pooling operator
145  qnnp_operator_t qnnpackGlobalOperator_{nullptr};
146 };
147 
148 } // namespace int8
149 
150 } // namespace caffe2
151 
152 #endif // CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13