1 #ifndef CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_ 2 #define CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/core/tensor_int8.h" 9 #include "caffe2/operators/conv_pool_op_base.h" 10 #include "caffe2/operators/quantized/int8_utils.h" 16 template <Activation Ac>
// Variadic-template fragment whose only visible statement declares, via
// OPERATOR_NEEDS_FEATURE, that this operator requires NHWC storage order.
// NOTE(review): presumably the body of the operator's constructor; the
// signature and any base-class initialization are not visible in this
// excerpt -- confirm against the full file.
19 template <
class... Args>
// Feature requirement on order_, with the message reported when it is not met.
22 OPERATOR_NEEDS_FEATURE(
23 this->order_ == StorageOrder::NHWC,
"Int8 only supports NHWC order.");
// Releases the QNNPACK operator handle if one exists.
// NOTE(review): presumably the destructor body; the enclosing signature and
// closing braces are not visible in this excerpt.
27 if (this->qnnpackOperator_ !=
nullptr) {
28 qnnp_delete_operator(this->qnnpackOperator_);
// Reset the handle after deletion so it no longer refers to the freed operator.
29 this->qnnpackOperator_ =
nullptr;
// Runs quantized max pooling through QNNPACK for the NHWC layout.
//
// Reads input 0 and writes output 0 as Int8TensorCPU, checks that the
// requested output quantization matches the input's, creates the QNNPACK
// operator if the cached handle is still null, then sets it up and runs it.
// NOTE(review): several lines of this function (output sizing, some call
// arguments, the enclosing CAFFE_ENFORCE( openers, any preprocessor
// conditionals, and the return statement) are not visible in this excerpt.
33 bool RunOnDeviceWithOrderNHWC()
override {
34 const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
35 auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
// The output takes over the input's zero point.
37 Y->zero_point = X.zero_point;
// Operator arguments "Y_zero_point" (default 0) and "Y_scale" (default 1)
// must equal the input's quantization parameters, as enforced by the
// CHECK_EQ calls below.
38 const int32_t Y_zero_point =
39 this->
template GetSingleArgument<int>(
"Y_zero_point", 0);
40 const float Y_scale = this->
template GetSingleArgument<float>(
"Y_scale", 1);
41 CHECK_EQ(Y_zero_point, X.zero_point);
42 CHECK_EQ(Y_scale, X.scale);
// Input must be 4-dimensional; the channel count is read from dimension 3.
44 CHECK_EQ(X.t.dim(), 4);
45 const int channels = X.t.dim32(3);
// Create the QNNPACK operator only while the cached handle is null, so an
// existing operator is reused on later calls.
50 if (this->qnnpackOperator_ ==
nullptr) {
51 const qnnp_status createStatus = qnnp_create_max_pooling2d_nhwc_u8(
52 pad_t(), pad_r(), pad_b(), pad_l(),
53 kernel_h(), kernel_w(),
54 stride_h(), stride_w(),
// The pair returned by activationLimits for activation Ac is passed as two
// separate arguments (presumably the output min and max -- confirm against
// the QNNPACK signature).
57 activationLimits(Y->scale, Y->zero_point, Ac).first,
58 activationLimits(Y->scale, Y->zero_point, Ac).second,
60 &this->qnnpackOperator_);
// Condition and message for the creation status check.
62 createStatus == qnnp_status_success,
63 "failed to create QNNPACK Max Pooling operator");
64 CAFFE_ENFORCE(this->qnnpackOperator_ !=
nullptr);
// Bind this call's tensors to the operator: the first three input
// dimensions, then the uint8 input and output buffers, each followed by
// `channels` (presumably batch/height/width and per-pixel strides --
// TODO confirm against the QNNPACK header).
67 const qnnp_status setupStatus = qnnp_setup_max_pooling2d_nhwc_u8(
68 this->qnnpackOperator_,
69 X.t.dim32(0), X.t.dim32(1), X.t.dim32(2),
70 X.t.template data<uint8_t>(), channels,
71 Y->t.template mutable_data<uint8_t>(), channels,
// Condition and message for the setup status check.
74 setupStatus == qnnp_status_success,
75 "failed to setup QNNPACK Max Pooling operator");
// NOTE(review): runStatus is declared twice below -- once running with a
// null thread pool and once with the workspace's thread pool. Presumably
// the two variants are selected by preprocessor conditionals that are not
// visible in this excerpt; verify against the full file.
78 const qnnp_status runStatus =
79 qnnp_run_operator(this->qnnpackOperator_,
nullptr );
// Variant that passes the thread pool obtained from the workspace (ws_).
81 pthreadpool_t threadpool =
82 reinterpret_cast<pthreadpool_t
>(ws_->GetThreadPool());
83 const qnnp_status runStatus =
84 qnnp_run_operator(this->qnnpackOperator_, threadpool);
// Condition and message for the run status check.
87 runStatus == qnnp_status_success,
88 "failed to run QNNPACK Max Pooling operator");
// Cached QNNPACK operator handle. Starts as nullptr; it is created in
// RunOnDeviceWithOrderNHWC while still null and released with
// qnnp_delete_operator in the cleanup code.
94 qnnp_operator_t qnnpackOperator_{
nullptr};
101 #endif // CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_ A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...