// NOTE(review): this text is an extracted source listing. The leading numbers
// on each line are the original file's line numbers, and several original
// lines (e.g. 3-5, 11-15, 17-18, 20-21, 24-25) are absent, so it does not
// compile as-is. Compare against the upstream header before editing code.
1 #ifndef CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_ 2 #define CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/core/tensor_int8.h" 9 #include "caffe2/operators/conv_pool_op_base.h" 10 #include "caffe2/operators/quantized/int8_utils.h" 16 template <Activation Ac>
// Int8 average-pooling operator, templated on a fused activation `Ac`.
// `Ac` is passed to activationLimits() when the QNNPACK operators are
// created (presumably to derive the output clamp bounds).
// Variadic constructor template. The declarator and initializer list
// (original lines 20-21) are missing from this extract.
19 template <
class... Args>
// Only NHWC storage order is supported; any other order fails this feature
// check with "Int8 only supports NHWC order."
22 OPERATOR_NEEDS_FEATURE(
23 this->order_ == StorageOrder::NHWC,
"Int8 only supports NHWC order.");
// Destructor: deletes whichever QNNPACK operator handles were created.
// RunOnDeviceWithOrderNHWC creates them lazily, so either may still be
// nullptr. Each handle is reset to nullptr after deletion.
// NOTE(review): the closing braces (original lines 30 and 34-35) are missing
// from this extract.
// NOTE(review): this type owns raw handles through a user-declared
// destructor. No deleted copy/move operations are visible here, so confirm
// the base class is non-copyable; otherwise a copy would double-delete.
26 ~Int8AveragePoolOp() {
// General (padded/strided) average-pooling operator.
27 if (this->qnnpackOperator_ !=
nullptr) {
28 qnnp_delete_operator(this->qnnpackOperator_);
29 this->qnnpackOperator_ =
nullptr;
// Global average-pooling operator.
31 if (this->qnnpackGlobalOperator_ !=
nullptr) {
32 qnnp_delete_operator(this->qnnpackGlobalOperator_);
33 this->qnnpackGlobalOperator_ =
nullptr;
// Runs int8 average pooling with QNNPACK. Input 0 is X (an Int8TensorCPU in
// NHWC order); the result is written to output 0, Y.
// There are two paths:
//  * global pooling, used when there is no padding, no stride > 1, and the
//    kernel covers the whole H x W plane (dims 1 and 2 of X);
//  * general 2-D average pooling otherwise.
// Each path creates its QNNPACK operator only when the cached handle is
// nullptr. After that, it sets up and runs the operator on every call.
// NOTE(review): this is an extracted listing with gaps. Missing original
// lines include 43, 45, 48-51, 57, 61, 66, 68, 72-73, 79, 82-83, 86-87 and
// the trailing return/closing braces; confirm against the upstream header.
37 bool RunOnDeviceWithOrderNHWC()
override {
38 const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
39 auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
// Requested output quantization parameters (defaults: zero point 0, scale 1).
40 int32_t Y_zero_point =
41 this->
template GetSingleArgument<int>(
"Y_zero_point", 0);
42 auto Y_scale = this->
template GetSingleArgument<float>(
"Y_scale", 1);
// NOTE(review): Y_scale is read above, but no `Y->scale = Y_scale;` is
// visible in this listing, even though Y->scale is passed to QNNPACK below.
// Original line 43 is missing here and presumably does that assignment —
// confirm.
44 Y->zero_point = Y_zero_point;
// X must be rank 4. With NHWC order, dim 3 is the channel count.
46 CHECK_EQ(X.t.dim(), 4);
47 const int channels = X.t.dim32(3);
// NOTE(review): original lines 48-51 are missing from this extract. Nothing
// visible here sizes Y->t before mutable_data<uint8_t>() is called below —
// confirm the output shape is set in those lines.
// Global pooling applies when each output pixel averages the entire H x W
// input plane: no padding, unit strides, and kernel == spatial size.
52 const bool anyPadding =
53 pad_t() != 0 || pad_r() != 0 || pad_b() != 0 || pad_l() != 0;
54 const bool anyStride = stride_h() > 1 || stride_w() > 1;
55 const bool globalPooling = !anyPadding && !anyStride &&
56 (X.t.dim32(1) == kernel_h() && X.t.dim32(2) == kernel_w());
// --- Global pooling path (presumably guarded by `if (globalPooling)` on the
// missing original line 57). ---
// Create the global-average-pooling operator once and cache it. X's and Y's
// zero points and scales, and the activationLimits() bounds, are passed at
// creation time, not at setup.
// NOTE(review): if these quantization parameters can differ between runs,
// the cached operator keeps the first run's values — confirm this is intended.
// Original lines 61 and 66 (arguments of this call) and the CAFFE_ENFORCE(
// opener on line 68 are missing from this extract.
58 if (this->qnnpackGlobalOperator_ ==
nullptr) {
59 const qnnp_status createStatus =
60 qnnp_create_global_average_pooling_nwc_q8(
62 X.zero_point, X.scale,
63 Y->zero_point, Y->scale,
64 activationLimits(Y->scale, Y->zero_point, Ac).first,
65 activationLimits(Y->scale, Y->zero_point, Ac).second,
67 &this->qnnpackGlobalOperator_);
69 createStatus == qnnp_status_success,
70 "failed to create QNNPACK Global Average Pooling operator");
71 CAFFE_ENFORCE(this->qnnpackGlobalOperator_ !=
nullptr);
// Bind this call's buffers. The batch size is dim 0. The H x W plane is
// flattened to H*W elements, and `channels` follows each data pointer
// (presumably the pixel stride).
74 const qnnp_status setupStatus = qnnp_setup_global_average_pooling_nwc_q8(
75 this->qnnpackGlobalOperator_,
76 X.t.dim32(0), X.t.dim32(1) * X.t.dim32(2),
77 X.t.template data<uint8_t>(), channels,
78 Y->t.template mutable_data<uint8_t>(), channels);
80 setupStatus == qnnp_status_success,
81 "failed to setup QNNPACK Global Average Pooling operator");
// Run the operator. The two `runStatus` declarations below cannot coexist in
// one scope, so they are presumably alternative preprocessor branches; the
// directives on original lines 83 and 86-87 are missing here. The first form
// presumably passes a null thread pool, as in the non-global path below. The
// second form uses the workspace thread pool obtained from ws_.
84 const qnnp_status runStatus =
85 qnnp_run_operator(this->qnnpackGlobalOperator_,
88 pthreadpool_t threadpool =
89 reinterpret_cast<pthreadpool_t
>(ws_->GetThreadPool());
90 const qnnp_status runStatus =
91 qnnp_run_operator(this->qnnpackGlobalOperator_, threadpool);
94 runStatus == qnnp_status_success,
95 "failed to run QNNPACK Global Average Pooling operator");
// --- General 2-D average pooling path (presumably the `else` branch). ---
// Create the operator once, with these arguments in order:
//  * padding (top, right, bottom, left), kernel (h, w) and stride (h, w);
//  * the quantization parameters of X and Y;
//  * the activationLimits() bounds.
// The same stale-parameter caveat as the global path applies.
97 if (this->qnnpackOperator_ ==
nullptr) {
98 const qnnp_status createStatus = qnnp_create_average_pooling2d_nhwc_q8(
99 pad_t(), pad_r(), pad_b(), pad_l(),
100 kernel_h(), kernel_w(),
101 stride_h(), stride_w(),
103 X.zero_point, X.scale,
104 Y->zero_point, Y->scale,
105 activationLimits(Y->scale, Y->zero_point, Ac).first,
106 activationLimits(Y->scale, Y->zero_point, Ac).second,
108 &this->qnnpackOperator_);
110 createStatus == qnnp_status_success,
111 "failed to create QNNPACK Average Pooling operator");
112 CAFFE_ENFORCE(this->qnnpackOperator_ !=
nullptr);
// Bind this call's buffers. Batch, height and width come from dims 0-2 of X,
// and `channels` follows each data pointer. The final argument of this call
// (original line 120) is missing from this extract.
115 const qnnp_status setupStatus = qnnp_setup_average_pooling2d_nhwc_q8(
116 this->qnnpackOperator_,
117 X.t.dim32(0), X.t.dim32(1), X.t.dim32(2),
118 X.t.template data<uint8_t>(), channels,
119 Y->t.template mutable_data<uint8_t>(), channels,
122 setupStatus == qnnp_status_success,
123 "failed to setup QNNPACK Average Pooling operator");
// Run the operator, either with a null thread pool or with the workspace
// thread pool. As above, these are presumably alternative preprocessor
// branches whose directives are missing here.
126 const qnnp_status runStatus =
127 qnnp_run_operator(this->qnnpackOperator_,
nullptr );
129 pthreadpool_t threadpool =
130 reinterpret_cast<pthreadpool_t
>(ws_->GetThreadPool());
131 const qnnp_status runStatus =
132 qnnp_run_operator(this->qnnpackOperator_, threadpool);
135 runStatus == qnnp_status_success,
136 "failed to run QNNPACK Average Pooling operator");
// QNNPACK handle for the general (padded/strided) average-pooling path.
// It starts as nullptr, is created on first use in RunOnDeviceWithOrderNHWC,
// and is deleted in the destructor.
143 qnnp_operator_t qnnpackOperator_{
nullptr};
// QNNPACK handle for the global-average-pooling path, used when the kernel
// covers the whole H x W plane. Same lifecycle as qnnpackOperator_.
145 qnnp_operator_t qnnpackGlobalOperator_{
nullptr};
// NOTE(review): original lines 146-151 (presumably the class and namespace
// closing lines) are missing from this extract.
152 #endif // CAFFE2_OPERATORS_INT8_AVERAGE_POOL_OP_H_