Caffe2 - C++ API
A deep learning, cross platform ML framework
int8_channel_shuffle_op.h
1 #ifndef CAFFE2_OPERATORS_INT8_CHANNEL_SHUFFLE_OP_H_
2 #define CAFFE2_OPERATORS_INT8_CHANNEL_SHUFFLE_OP_H_
3 
4 #include <qnnpack.h>
5 
6 #include "caffe2/core/context.h"
7 #include "caffe2/core/operator.h"
8 #include "caffe2/core/tensor_int8.h"
9 #include "caffe2/operators/conv_pool_op_base.h"
10 #include "caffe2/operators/quantized/int8_utils.h"
11 
12 namespace caffe2 {
13 
14 namespace int8 {
15 
16 class Int8ChannelShuffleOp final : public ConvPoolOpBase<CPUContext> {
17  public:
18  explicit Int8ChannelShuffleOp(const OperatorDef& operator_def, Workspace* ws)
19  : ConvPoolOpBase<CPUContext>(operator_def, ws), ws_(ws) {
20  OPERATOR_NEEDS_FEATURE(
21  this->order_ == StorageOrder::NHWC,
22  "Int8ChannelShuffleOp only supports NHWC order");
23  }
24 
26  if (this->qnnpackOperator_ != nullptr) {
27  qnnp_delete_operator(this->qnnpackOperator_);
28  this->qnnpackOperator_ = nullptr;
29  }
30  }
31 
32  bool RunOnDeviceWithOrderNHWC() override {
33  const auto& X = Inputs()[0]->template Get<Int8TensorCPU>();
34  auto* Y = Outputs()[0]->template GetMutable<Int8TensorCPU>();
35  Y->t.ResizeLike(X.t);
36  Y->scale = X.scale;
37  Y->zero_point = X.zero_point;
38  const int32_t Y_offset = this->template GetSingleArgument<int>("Y_zero_point", 0);
39  const float Y_scale = this->template GetSingleArgument<float>("Y_scale", 1.0f);
40  CHECK_EQ(Y_offset, X.zero_point);
41  CHECK_EQ(Y_scale, X.scale);
42  CHECK_GE(X.zero_point, std::numeric_limits<uint8_t>::min());
43  CHECK_LE(X.zero_point, std::numeric_limits<uint8_t>::max());
44 
45  const auto C = X.t.dim32(3);
46  const auto G = this->group_;
47  CAFFE_ENFORCE(C % G == 0, "");
48  const auto B = X.t.numel() / C;
49 
50  initQNNPACK();
51 
52  if (this->qnnpackOperator_ == nullptr) {
53  const qnnp_status createStatus = qnnp_create_channel_shuffle_nc_x8(
54  G /* groups */,
55  C / G /* group channels */,
56  0 /* flags */,
57  &this->qnnpackOperator_);
58  CAFFE_ENFORCE(
59  createStatus == qnnp_status_success,
60  "failed to create QNNPACK channel shuffle operator");
61  CAFFE_ENFORCE(this->qnnpackOperator_ != nullptr);
62  }
63 
64  const qnnp_status setupStatus = qnnp_setup_channel_shuffle_nc_x8(
65  this->qnnpackOperator_,
66  X.t.numel() / C /* batch size */,
67  X.t.template data<uint8_t>(),
68  C /* X stride */,
69  Y->t.template mutable_data<uint8_t>(),
70  C /* Y stride */);
71  CAFFE_ENFORCE(
72  setupStatus == qnnp_status_success,
73  "failed to setup QNNPACK channel shuffle operator");
74 
75 #ifdef FBCODE_CAFFE2
76  const qnnp_status runStatus =
77  qnnp_run_operator(this->qnnpackOperator_, nullptr /* thread pool */);
78 #else
79  pthreadpool_t threadpool =
80  reinterpret_cast<pthreadpool_t>(ws_->GetThreadPool());
81  const qnnp_status runStatus =
82  qnnp_run_operator(this->qnnpackOperator_, threadpool);
83 #endif
84  CAFFE_ENFORCE(
85  runStatus == qnnp_status_success,
86  "failed to run QNNPACK channel shuffle operator");
87 
88  return true;
89  }
90 
91  private:
92  Workspace* ws_;
93  // QNNPACK channel shuffle operator
94  qnnp_operator_t qnnpackOperator_{nullptr};
95 };
96 
97 } // namespace int8
98 
99 } // namespace caffe2
100 
101 #endif // CAFFE2_OPERATORS_INT8_CHANNEL_SHUFFLE_OP_H_
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
Definition: static.cpp:58