Caffe2 - C++ API
A deep learning, cross-platform ML framework
int8_conv_transpose_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_
2 #define CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_
3 
4 #include <qnnpack.h>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/core/tensor_int8.h"
9 #include "caffe2/operators/conv_op_shared.h"
10 #include "caffe2/operators/conv_transpose_unpool_op_base.h"
11 #include "caffe2/operators/quantized/int8_utils.h"
12 
13 namespace caffe2 {
14 
15 namespace int8 {
16 
17 class Int8ConvTransposeOp final : public ConvTransposeUnpoolBase<CPUContext> {
18  public:
19  USE_CONV_TRANSPOSE_UNPOOL_BASE_FUNCTIONS(CPUContext);
20  template <class... Args>
21  explicit Int8ConvTransposeOp(Args&&... args)
22  : ConvTransposeUnpoolBase(std::forward<Args>(args)...) {
23  OPERATOR_NEEDS_FEATURE(
24  this->order_ == StorageOrder::NHWC,
25  "Int8ConvTransposeOp only supports NHWC order");
26  createSharedBuffer<CPUContext>(ws_);
27  }
28 
29  ~Int8ConvTransposeOp() {
30  if (this->qnnpackObject_ != nullptr) {
31  qnnp_delete_operator(this->qnnpackObject_);
32  this->qnnpackObject_ = nullptr;
33  }
34  }
35 
36  bool RunOnDeviceWithOrderNHWC() override {
37  CAFFE_ENFORCE_EQ(Inputs().size(), 3);
38  const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
39  const auto& W = Inputs()[1]->template Get<Int8TensorCPU>();
40  const auto& B = Inputs()[2]->template Get<Int8TensorCPU>();
41  auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
42  const auto X_offset = -X.zero_point;
43  const auto W_offset = -W.zero_point;
44  const int32_t Y_offset =
45  this->template GetSingleArgument<int>("Y_zero_point", 0);
46  double Y_scale = this->template GetSingleArgument<float>("Y_scale", 1);
47  Y->scale = Y_scale;
48  Y->zero_point = Y_offset;
49 
50  const auto N = X.t.size(0);
51  const auto IH = X.t.size(1);
52  const auto IW = X.t.size(2);
53  const auto IC = X.t.size(3);
54 
55  CHECK_EQ(IC, W.t.size(0));
56  const auto KH = W.t.size(1);
57  const auto KW = W.t.size(2);
58  const auto OC = W.t.size(3);
59 
61  ReinitializeTensor(&(Y->t), sizes, at::dtype<uint8_t>().device(CPU));
62  CHECK_EQ(OC, Y->t.size(3));
63 
64  runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
65  initQNNPACK();
66 
67  pthreadpool_t threadpool =
68  reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
69 
70  if (this->qnnpackObject_ == nullptr) {
71  const qnnp_status createStatus = qnnp_create_deconvolution2d_nhwc_q8(
72  pad_t(),
73  pad_r(),
74  pad_b(),
75  pad_l(),
76  adj_h(),
77  adj_w(),
78  KH,
79  KW,
80  stride_h(),
81  stride_w(),
82  1 /* dilation height */,
83  1 /* dilation width */,
84  1 /* groups */,
85  IC,
86  OC,
87  X.zero_point,
88  X.scale,
89  W.zero_point,
90  W.scale,
91 #ifndef _MSC_VER
92  W.t.template data<uint8_t>(),
93  B.t.template data<int32_t>(),
94 #else
95  W.t.data<uint8_t>(),
96  B.t.data<int32_t>(),
97 #endif
98  Y->zero_point,
99  Y->scale,
100  std::numeric_limits<uint8_t>::min(),
101  std::numeric_limits<uint8_t>::max(),
102  0 /* flags */,
103  &this->qnnpackObject_);
104  CAFFE_ENFORCE(
105  createStatus == qnnp_status_success,
106  "failed to create QNNPACK convolution object");
107  CAFFE_ENFORCE(this->qnnpackObject_ != nullptr);
108  }
109 
110  uint8_t* inputPtr = X.t.template mutable_data<uint8_t>();
111  if (IC < 8) {
112  buffer->Resize(std::vector<int64_t>{X.t.numel() + 8});
113  inputPtr = buffer->template mutable_data<uint8_t>() + 8;
114  memcpy(inputPtr, X.t.template data<uint8_t>(), X.t.numel());
115  }
116 
117  if (lastBatchSize_ != static_cast<size_t>(X.t.size(0)) ||
118  lastInputHeight_ != static_cast<size_t>(X.t.size(1)) ||
119  lastInputWidth_ != static_cast<size_t>(X.t.size(2)) ||
120  lastInputPointer_ != inputPtr ||
121  lastOutputPointer_ != Y->t.template mutable_data<uint8_t>()) {
122  const qnnp_status setupStatus = qnnp_setup_deconvolution2d_nhwc_q8(
123  this->qnnpackObject_,
124  X.t.size(0),
125  X.t.size(1),
126  X.t.size(2),
127  inputPtr,
128  X.t.size(3) /* input pixel stride */,
129  Y->t.template mutable_data<uint8_t>(),
130  Y->t.size(3) /* output pixel stride */,
131  nullptr /* threadpool */);
132  CAFFE_ENFORCE(
133  setupStatus == qnnp_status_success,
134  "failed to setup QNNPACK convolution object");
135 
136  lastBatchSize_ = static_cast<size_t>(X.t.size(0));
137  lastInputHeight_ = static_cast<size_t>(X.t.size(1));
138  lastInputWidth_ = static_cast<size_t>(X.t.size(2));
139  lastInputPointer_ = inputPtr;
140  lastOutputPointer_ = Y->t.template mutable_data<uint8_t>();
141  }
142 
143 #ifdef FBCODE_CAFFE2
144  const qnnp_status runStatus =
145  qnnp_run_operator(this->qnnpackObject_, nullptr /* thread pool */);
146 #else
147  const qnnp_status runStatus =
148  qnnp_run_operator(this->qnnpackObject_, threadpool);
149 #endif
150  CAFFE_ENFORCE(
151  runStatus == qnnp_status_success,
152  "failed to run QNNPACK convolution");
153  });
154  return true;
155  }
156 
157  private:
158  // QNNPACK convolution object
159  qnnp_operator_t qnnpackObject_{nullptr};
160  // batch size in the previous call to RunOnDeviceWithOrderNHWC
161  size_t lastBatchSize_{0};
162  // input height in the previous call to RunOnDeviceWithOrderNHWC
163  size_t lastInputHeight_{0};
164  // input width in the previous call to RunOnDeviceWithOrderNHWC
165  size_t lastInputWidth_{0};
166  // input pointer in the previous call to RunOnDeviceWithOrderNHWC
167  const void* lastInputPointer_{nullptr};
168  // output pointer in the previous call to RunOnDeviceWithOrderNHWC
169  void* lastOutputPointer_{nullptr};
170 };
171 
172 } // namespace int8
173 
174 } // namespace caffe2
175 
176 #endif // CAFFE2_OPERATORS_INT8_CONV_TRANSPOSE_OP_H_
void ReinitializeTensor(Tensor *tensor, at::IntArrayRef dims, at::TensorOptions options)
Reinitialize a Tensor to given dims and options if necessary, note that this will not do anything if ...
Definition: tensor.cc:127
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:58