Caffe2 - C++ API
A deep learning, cross platform ML framework
int8_max_pool_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_
2 #define CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_
3 
4 #include <qnnpack.h>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/core/tensor_int8.h"
9 #include "caffe2/operators/conv_pool_op_base.h"
10 #include "caffe2/operators/quantized/int8_utils.h"
11 
12 namespace caffe2 {
13 
14 namespace int8 {
15 
16 template <Activation Ac>
17 class Int8MaxPoolOp final : public ConvPoolOpBase<CPUContext> {
18  public:
19  template <class... Args>
20  explicit Int8MaxPoolOp(Args&&... args)
21  : ConvPoolOpBase<CPUContext>(std::forward<Args>(args)...) {
22  OPERATOR_NEEDS_FEATURE(
23  this->order_ == StorageOrder::NHWC, "Int8 only supports NHWC order.");
24  }
25 
26  ~Int8MaxPoolOp() {
27  if (this->qnnpackOperator_ != nullptr) {
28  qnnp_delete_operator(this->qnnpackOperator_);
29  this->qnnpackOperator_ = nullptr;
30  }
31  }
32 
33  bool RunOnDeviceWithOrderNHWC() override {
34  const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
35  auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
36  Y->scale = X.scale;
37  Y->zero_point = X.zero_point;
38  const int32_t Y_zero_point =
39  this->template GetSingleArgument<int>("Y_zero_point", 0);
40  const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
41  CHECK_EQ(Y_zero_point, X.zero_point);
42  CHECK_EQ(Y_scale, X.scale);
43 
44  CHECK_EQ(X.t.dim(), 4);
45  const int channels = X.t.dim32(3);
46  ConvPoolOpBase<CPUContext>::SetOutputSize(X.t, &(Y->t), channels);
47 
48  initQNNPACK();
49 
50  if (this->qnnpackOperator_ == nullptr) {
51  const qnnp_status createStatus = qnnp_create_max_pooling2d_nhwc_u8(
52  pad_t(), pad_r(), pad_b(), pad_l(),
53  kernel_h(), kernel_w(),
54  stride_h(), stride_w(),
55  1 /* dilation height */, 1 /* dilation width */,
56  channels,
57  activationLimits(Y->scale, Y->zero_point, Ac).first,
58  activationLimits(Y->scale, Y->zero_point, Ac).second,
59  0 /* flags */,
60  &this->qnnpackOperator_);
61  CAFFE_ENFORCE(
62  createStatus == qnnp_status_success,
63  "failed to create QNNPACK Max Pooling operator");
64  CAFFE_ENFORCE(this->qnnpackOperator_ != nullptr);
65  }
66 
67  const qnnp_status setupStatus = qnnp_setup_max_pooling2d_nhwc_u8(
68  this->qnnpackOperator_,
69  X.t.dim32(0), X.t.dim32(1), X.t.dim32(2),
70  X.t.template data<uint8_t>(), channels,
71  Y->t.template mutable_data<uint8_t>(), channels,
72  nullptr /* thread pool */);
73  CAFFE_ENFORCE(
74  setupStatus == qnnp_status_success,
75  "failed to setup QNNPACK Max Pooling operator");
76 
77 #ifdef FBCODE_CAFFE2
78  const qnnp_status runStatus =
79  qnnp_run_operator(this->qnnpackOperator_, nullptr /* thread pool */);
80 #else
81  pthreadpool_t threadpool =
82  reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
83  const qnnp_status runStatus =
84  qnnp_run_operator(this->qnnpackOperator_, threadpool);
85 #endif
86  CAFFE_ENFORCE(
87  runStatus == qnnp_status_success,
88  "failed to run QNNPACK Max Pooling operator");
89  return true;
90  }
91 
92  private:
93  // QNNPACK Max Pooling operator
94  qnnp_operator_t qnnpackOperator_{nullptr};
95 };
96 
97 } // namespace int8
98 
99 } // namespace caffe2
100 
101 #endif // CAFFE2_OPERATORS_INT8_MAX_POOL_OP_H_
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13