1 #ifndef CAFFE2_OPERATORS_INT8_CONV_OP_H_ 2 #define CAFFE2_OPERATORS_INT8_CONV_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/core/tensor_int8.h" 9 #include "caffe2/operators/conv_op_shared.h" 10 #include "caffe2/operators/conv_pool_op_base.h" 11 #include "caffe2/operators/quantized/int8_utils.h" 17 template <Activation Ac>
21 template <
class... Args>
24 OPERATOR_NEEDS_FEATURE(
25 this->order_ == StorageOrder::NHWC,
26 "Int8Conv only supports NHWC order");
27 createSharedBuffer<CPUContext>(ws_);
31 if (this->qnnpackObject_ !=
nullptr) {
32 qnnp_delete_operator(this->qnnpackObject_);
33 this->qnnpackObject_ =
nullptr;
37 bool RunOnDeviceWithOrderNHWC()
override {
38 CAFFE_ENFORCE_EQ(Inputs().size(), 3);
39 const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
40 const auto& W = Inputs()[1]->template Get<Int8TensorCPU>();
41 const auto&
B = Inputs()[2]->template Get<Int8TensorCPU>();
42 auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
43 const int32_t Y_offset =
44 this->
template GetSingleArgument<int>(
"Y_zero_point", 0);
45 double Y_scale = this->
template GetSingleArgument<float>(
"Y_scale", 1);
49 Y->zero_point = Y_offset;
51 const auto M = W.t.size(0);
52 const auto KH = W.t.size(1);
53 const auto KW = W.t.size(2);
54 const auto KC = W.t.size(3);
55 const auto C = X.t.dim32(3);
56 const bool isDepthwise = this->group_ > 1 && this->group_ ==
M &&
57 this->group_ ==
C && KC == 1 && KH * KW == 9 && dilation_w() == 1;
59 CHECK_EQ(Y->t.dim32(3),
M);
60 runWithSharedBuffer<CPUContext>(ws_, [&](
Tensor* buffer) {
63 pthreadpool_t threadpool =
64 reinterpret_cast<pthreadpool_t
>(ws_->GetThreadPool());
66 if (this->qnnpackObject_ ==
nullptr) {
68 C % this->group_ == 0,
69 "number of input channels must be divisible by groups count");
71 M % this->group_ == 0,
72 "number of output channels must be divisible by groups count");
73 const qnnp_status createStatus = qnnp_create_convolution2d_nhwc_q8(
92 W.t.template data<uint8_t>(),
93 B.t.template data<int32_t>(),
100 activationLimits(Y->scale, Y->zero_point, Ac).first,
101 activationLimits(Y->scale, Y->zero_point, Ac).second,
103 &this->qnnpackObject_);
105 createStatus == qnnp_status_success,
106 "failed to create QNNPACK convolution object");
107 CAFFE_ENFORCE(this->qnnpackObject_ !=
nullptr);
110 uint8_t* inputPtr = X.t.template mutable_data<uint8_t>();
111 if ((isDepthwise && this->group_ < 8) ||
112 (!isDepthwise &&
C / this->group_ < 8)) {
113 buffer->Resize(std::vector<int64_t>{X.t.numel() + 8});
114 inputPtr = buffer->template mutable_data<uint8_t>() + 8;
115 memcpy(inputPtr, X.t.template data<uint8_t>(), X.t.numel());
118 if (lastBatchSize_ != static_cast<size_t>(X.t.size(0)) ||
119 lastInputHeight_ != static_cast<size_t>(X.t.size(1)) ||
120 lastInputWidth_ != static_cast<size_t>(X.t.size(2)) ||
121 lastInputPointer_ != inputPtr ||
122 lastOutputPointer_ != Y->t.template mutable_data<uint8_t>()) {
123 const qnnp_status setupStatus = qnnp_setup_convolution2d_nhwc_q8(
124 this->qnnpackObject_,
130 Y->t.template mutable_data<uint8_t>(),
134 setupStatus == qnnp_status_success,
135 "failed to setup QNNPACK convolution object");
137 lastBatchSize_ =
static_cast<size_t>(X.t.size(0));
138 lastInputHeight_ =
static_cast<size_t>(X.t.size(1));
139 lastInputWidth_ =
static_cast<size_t>(X.t.size(2));
140 lastInputPointer_ = inputPtr;
141 lastOutputPointer_ = Y->t.template mutable_data<uint8_t>();
145 const qnnp_status runStatus =
146 qnnp_run_operator(this->qnnpackObject_,
nullptr );
148 const qnnp_status runStatus =
149 qnnp_run_operator(this->qnnpackObject_, threadpool);
152 runStatus == qnnp_status_success,
153 "failed to run QNNPACK convolution");
160 qnnp_operator_t qnnpackObject_{
nullptr};
162 size_t lastBatchSize_{0};
164 size_t lastInputHeight_{0};
166 size_t lastInputWidth_{0};
168 const void* lastInputPointer_{
nullptr};
170 void* lastOutputPointer_{
nullptr};
177 #endif // CAFFE2_OPERATORS_INT8_CONV_OP_H_
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...