1 #ifndef CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_ 2 #define CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_ 6 #include "caffe2/core/context.h" 7 #include "caffe2/core/operator.h" 8 #include "caffe2/core/tensor_int8.h" 9 #include "caffe2/operators/conv_op_shared.h" 10 #include "caffe2/operators/conv_transpose_unpool_op_base.h" 11 #include "caffe2/operators/quantized/int8_utils.h" 19 USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(
CPUContext);
// Constructor fragment. The template header below introduces a variadic
// parameter pack; the constructor signature and base-class initialisation
// (listing lines 21-22) are not visible in this excerpt.
20 template <
class... Args>
// Refuse to build the operator unless the storage order is NHWC.
23 OPERATOR_NEEDS_FEATURE(
24 this->order_ == StorageOrder::NHWC,
25 "Int8ConvTransposeOp only supports NHWC order");
// Create the shared buffer on workspace ws_; RunOnDeviceWithOrderNHWC later
// obtains it through runWithSharedBuffer<CPUContext>(ws_, ...).
26 createSharedBuffer<CPUContext>(ws_);
// Destructor: releases the QNNPACK operator handle, if one was created.
29 ~Int8ConvTransposeOp() {
// qnnpackObject_ stays nullptr until the first run creates it, so only
// delete when it is set.
30 if (this->qnnpackObject_ !=
nullptr) {
31 qnnp_delete_operator(this->qnnpackObject_);
// Reset the handle so it does not keep pointing at the deleted operator.
32 this->qnnpackObject_ =
nullptr;
// Runs the quantized transposed convolution (deconvolution) on NHWC data via
// QNNPACK. Inputs are X (input), W (filter) and B (bias), all Int8TensorCPU;
// the single output is Y.
// NOTE(review): several lines of the original listing are missing from this
// excerpt (e.g. output resizing, most QNNPACK call arguments, preprocessor
// guards and the return statement), so the comments below describe only the
// statements that are visible.
36 bool RunOnDeviceWithOrderNHWC()
override {
// Exactly three inputs are required.
37 CAFFE_ENFORCE_EQ(Inputs().size(), 3);
38 const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
39 const auto& W = Inputs()[1]->template Get<Int8TensorCPU>();
40 const auto&
B = Inputs()[2]->template Get<Int8TensorCPU>();
41 auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
// Negated zero points of X and W; where they are consumed is not visible in
// this excerpt.
42 const auto X_offset = -X.zero_point;
43 const auto W_offset = -W.zero_point;
// Output quantization parameters come from the operator arguments
// "Y_zero_point" (default 0) and "Y_scale" (default 1). Y_scale is read as a
// float and stored in a double.
44 const int32_t Y_offset =
45 this->
template GetSingleArgument<int>(
"Y_zero_point", 0);
46 double Y_scale = this->
template GetSingleArgument<float>(
"Y_scale", 1);
48 Y->zero_point = Y_offset;
// Dimensions 0..3 of X are read as batch, input height, input width and
// input channels.
50 const auto N = X.t.size(0);
51 const auto IH = X.t.size(1);
52 const auto IW = X.t.size(2);
53 const auto IC = X.t.size(3);
// Dimension 0 of W must equal the input channel count; dimensions 1..3 of W
// are read as kernel height, kernel width and output channels.
55 CHECK_EQ(IC, W.t.size(0));
56 const auto KH = W.t.size(1);
57 const auto KW = W.t.size(2);
58 const auto OC = W.t.size(3);
// The last dimension of Y must equal the output channel count.
62 CHECK_EQ(OC, Y->t.size(3));
// The rest of the work runs inside a callback that receives the shared
// buffer created for workspace ws_.
64 runWithSharedBuffer<CPUContext>(ws_, [&](
Tensor* buffer) {
// Thread pool taken from the workspace, cast to the pthreadpool handle type.
67 pthreadpool_t threadpool =
68 reinterpret_cast<pthreadpool_t
>(ws_->GetThreadPool());
// Create the QNNPACK deconvolution operator only once; later calls reuse the
// cached handle. Only part of the argument list is visible in this excerpt.
70 if (this->qnnpackObject_ ==
nullptr) {
71 const qnnp_status createStatus = qnnp_create_deconvolution2d_nhwc_q8(
92 W.t.template data<uint8_t>(),
93 B.t.template data<int32_t>(),
// Full uint8 range; presumably the output clamp bounds - confirm against the
// QNNPACK signature.
100 std::numeric_limits<uint8_t>::min(),
101 std::numeric_limits<uint8_t>::max(),
103 &this->qnnpackObject_);
// Arguments of a status check whose opening line is missing from this
// excerpt.
105 createStatus == qnnp_status_success,
106 "failed to create QNNPACK convolution object");
107 CAFFE_ENFORCE(this->qnnpackObject_ !=
nullptr);
// Input pointer initially refers to X's own storage.
110 uint8_t* inputPtr = X.t.template mutable_data<uint8_t>();
// The shared buffer is resized to X.numel() + 8 bytes and X is copied in at
// an offset of 8 bytes; inputPtr is redirected to that copy.
// NOTE(review): listing line 111 (likely a guard condition for this copy) is
// missing here, and the reason for the 8 leading bytes is not shown - confirm.
112 buffer->Resize(std::vector<int64_t>{X.t.numel() + 8});
113 inputPtr = buffer->template mutable_data<uint8_t>() + 8;
114 memcpy(inputPtr, X.t.template data<uint8_t>(), X.t.numel());
// Re-run QNNPACK setup only when the batch size, input height, input width,
// input pointer or output pointer differs from the values cached at the last
// setup.
117 if (lastBatchSize_ != static_cast<size_t>(X.t.size(0)) ||
118 lastInputHeight_ != static_cast<size_t>(X.t.size(1)) ||
119 lastInputWidth_ != static_cast<size_t>(X.t.size(2)) ||
120 lastInputPointer_ != inputPtr ||
121 lastOutputPointer_ != Y->t.template mutable_data<uint8_t>()) {
122 const qnnp_status setupStatus = qnnp_setup_deconvolution2d_nhwc_q8(
123 this->qnnpackObject_,
129 Y->t.template mutable_data<uint8_t>(),
133 setupStatus == qnnp_status_success,
134 "failed to setup QNNPACK convolution object");
// Remember the configuration that setup was just performed for.
136 lastBatchSize_ =
static_cast<size_t>(X.t.size(0));
137 lastInputHeight_ =
static_cast<size_t>(X.t.size(1));
138 lastInputWidth_ =
static_cast<size_t>(X.t.size(2));
139 lastInputPointer_ = inputPtr;
140 lastOutputPointer_ = Y->t.template mutable_data<uint8_t>();
// Two alternative definitions of runStatus follow: one runs the operator
// without a thread pool, the other with the workspace thread pool.
// NOTE(review): the lines that select between them (presumably preprocessor
// conditionals) are not visible in this excerpt.
144 const qnnp_status runStatus =
145 qnnp_run_operator(this->qnnpackObject_,
nullptr );
147 const qnnp_status runStatus =
148 qnnp_run_operator(this->qnnpackObject_, threadpool);
// Arguments of a status check on the run result; its opening line is missing
// from this excerpt.
151 runStatus == qnnp_status_success,
152 "failed to run QNNPACK convolution");
// QNNPACK operator handle. It is nullptr until RunOnDeviceWithOrderNHWC
// creates it, and the destructor deletes it.
159 qnnp_operator_t qnnpackObject_{
nullptr};
// Values recorded at the most recent QNNPACK setup call.
// RunOnDeviceWithOrderNHWC compares the current batch size, input height,
// input width, input pointer and output pointer against these and repeats
// setup only when one of them differs.
161 size_t lastBatchSize_{0};
163 size_t lastInputHeight_{0};
165 size_t lastInputWidth_{0};
167 const void* lastInputPointer_{
nullptr};
169 void* lastOutputPointer_{
nullptr};
176 #endif // CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_ void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to the given dims and options if necessary. Note that this will not do anything if ...
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...