Caffe2 - C++ API
A deep learning, cross-platform ML framework
group_norm_dnnlowp_op.h
1 #pragma once
2 
3 #include <vector>
4 
5 #include "caffe2/operators/group_norm_op.h"
6 #include "caffe2/quantization/server/dnnlowp_op.h"
7 
8 namespace caffe2 {
9 
// Float CPU GroupNorm operator type; it is passed below as the second template
// argument of the DNNLowPOp base class and of the DNNLOWP base-functions macro.
10 using GroupNormFP32Op = GroupNormOp<float, CPUContext>;
11 
// Group normalization operator layered on DNNLowPOp<T, GroupNormFP32Op>.
// Inputs are X, gamma and beta; outputs are Y, mu and inv_sig (see the tag
// declarations at the end of the class).
//
// Most private helpers come in pairs: a "Quantized" form that works on int32_t
// buffers and a "Dequantized" form that works on float buffers; several also
// have separate NCHW and NHWC variants. Only declarations appear here; the
// definitions are not part of this header.
//
// NOTE(review): T is presumably the quantized element type (e.g. uint8_t);
// the instantiations are not visible here -- confirm against the .cc file.
12 template <typename T>
13 class GroupNormDNNLowPOp final : public DNNLowPOp<T, GroupNormFP32Op> {
14  public:
15  USE_OPERATOR_FUNCTIONS(CPUContext);
16  USE_DNNLOWP_OPERATOR_BASE_FUNCTIONS(T, GroupNormFP32Op);
17 
  // Takes the operator definition and the workspace pointer. The const data
  // members below have no default initializers, so the constructor's
  // initializer list must set each of them.
18  GroupNormDNNLowPOp(const OperatorDef& operator_def, Workspace* ws);
19 
  // Overrides the base-class RunOnDevice(); returns a bool status.
  // NOTE(review): presumably dispatches to the NCHW/NHWC helpers below based
  // on order_ -- confirm in the definition.
20  bool RunOnDevice() override;
21 
22  private:
  // NOTE(review): presumably chooses the quantization parameters used by a
  // run; the meaning of the bool result is not visible here -- TODO confirm.
23  bool GetQuantizationParameters();
24 
  // Parameter preparation helpers for gamma and beta.
  // NOTE(review): presumably these fill the gamma_/beta_ buffers and data
  // pointers declared further down -- confirm in the definitions.
25  void QuantizeGamma();
26 
27  void QuantizeGammaImpl();
28 
29  void QuantizeBeta();
30 
  // Layout-specific run paths, one per storage order; both return bool.
31  bool RunOnDeviceWithOrderNCHW();
32 
33  bool RunOnDeviceWithOrderNHWC();
34 
  // Group moments with int32_t outputs. X is read-only; mu and rsig are
  // written through non-const pointers.
  // NOTE(review): N, G, K and HxW are presumably batch size, number of groups,
  // channels per group and spatial size (H * W) -- confirm against callers.
35  void QuantizedGroupMomentsNCHW(
36  int N,
37  int G,
38  int K,
39  int HxW,
40  const T* X,
41  int32_t* mu,
42  int32_t* rsig);
43 
  // NHWC counterpart of QuantizedGroupMomentsNCHW (identical parameter list).
44  void QuantizedGroupMomentsNHWC(
45  int N,
46  int G,
47  int K,
48  int HxW,
49  const T* X,
50  int32_t* mu,
51  int32_t* rsig);
52 
  // Group moments with float outputs; X is still passed as const T*.
53  void DequantizedGroupMomentsNCHW(
54  int N,
55  int G,
56  int K,
57  int HxW,
58  const T* X,
59  float* mu,
60  float* rsig);
61 
  // NHWC counterpart of DequantizedGroupMomentsNCHW (identical parameter list).
62  void DequantizedGroupMomentsNHWC(
63  int N,
64  int G,
65  int K,
66  int HxW,
67  const T* X,
68  float* mu,
69  float* rsig);
70 
  // Reads var (const float*) and writes both a float buffer (rsig) and an
  // int32_t buffer (rsig_quantized).
  // NOTE(review): by its name this derives the inverse standard deviation from
  // the variance, presumably using epsilon_ -- confirm in the definition.
71  void ComputeQuantizedInvStd(
72  int N,
73  const float* var,
74  float* rsig,
75  int32_t* rsig_quantized);
76 
  // int32_t form: mu, rsig, gamma and beta are read-only inputs; scale and
  // bias are outputs.
  // NOTE(review): presumably folds the statistics and the affine parameters
  // into one per-channel scale/bias pair -- confirm in the definition.
77  void ComputeQuantizedFusedParams(
78  int N,
79  int G,
80  int K,
81  const int32_t* mu,
82  const int32_t* rsig,
83  const int32_t* gamma,
84  const int32_t* beta,
85  int32_t* scale,
86  int32_t* bias);
87 
  // float form of ComputeQuantizedFusedParams with the same parameter roles.
88  void ComputeDequantizedFusedParams(
89  int N,
90  int G,
91  int K,
92  const float* mu,
93  const float* rsig,
94  const float* gamma,
95  const float* beta,
96  float* scale,
97  float* bias);
98 
  // Applies int32_t scale/bias to X (const T*) and writes Y (T*).
  // NOTE(review): presumably Y = X * scale + bias per (batch, channel) followed
  // by requantization to T -- confirm in the definition.
99  void AffineBatchChannelQuantizedNCHW(
100  int N,
101  int C,
102  int HxW,
103  const T* X,
104  const int32_t* scale,
105  const int32_t* bias,
106  T* Y);
107 
  // NHWC counterpart of AffineBatchChannelQuantizedNCHW.
108  void AffineBatchChannelQuantizedNHWC(
109  int N,
110  int C,
111  int HxW,
112  const T* X,
113  const int32_t* scale,
114  const int32_t* bias,
115  T* Y);
116 
  // float form: unlike the Quantized variants, X and Y are float buffers here
  // rather than T buffers.
117  void AffineBatchChannelDequantizedNCHW(
118  int N,
119  int C,
120  int HxW,
121  const float* X,
122  const float* scale,
123  const float* bias,
124  float* Y);
125 
  // NHWC counterpart of AffineBatchChannelDequantizedNCHW.
126  void AffineBatchChannelDequantizedNHWC(
127  int N,
128  int C,
129  int HxW,
130  const float* X,
131  const float* scale,
132  const float* bias,
133  float* Y);
134 
  // Immutable configuration; all const, so fixed at construction.
  // NOTE(review): presumably read from operator arguments -- confirm in the
  // constructor definition.
135  const bool is_test_;
136  const int group_;
137  const float epsilon_;
138  const StorageOrder order_;
139  const bool is_param_constant_;
140 
  // Storage for mu / rsig in int32_t and float form, plus a set of
  // quantization parameters named for rsig.
141  std::vector<int32_t> mu_quantized_;
142  std::vector<int32_t> rsig_quantized_;
143  std::vector<float> mu_dequantized_;
144  std::vector<float> rsig_dequantized_;
145  dnnlowp::TensorQuantizationParams rsig_qparams_;
146 
  // Storage for gamma / beta in int32_t and float form, and read-only
  // pointers (initially nullptr) for each form.
  // NOTE(review): the pointers presumably refer either to the vectors above or
  // directly to input tensor data -- confirm where they are assigned.
147  std::vector<int32_t> gamma_quantized_;
148  std::vector<int32_t> beta_quantized_;
149  std::vector<float> gamma_dequantized_;
150  std::vector<float> beta_dequantized_;
151  const int32_t* gamma_quantized_data_ = nullptr;
152  const int32_t* beta_quantized_data_ = nullptr;
153  const float* gamma_dequantized_data_ = nullptr;
154  const float* beta_dequantized_data_ = nullptr;
155 
  // Storage for the scale / bias buffers in int32_t and float form, plus a
  // second set of quantization parameters.
  // NOTE(review): internal_qparams_ presumably describes an intermediate
  // quantization domain -- confirm in the definitions.
156  std::vector<int32_t> scale_quantized_;
157  std::vector<int32_t> bias_quantized_;
158  std::vector<float> scale_dequantized_;
159  std::vector<float> bias_dequantized_;
160  dnnlowp::TensorQuantizationParams internal_qparams_;
161 
  // Scratch buffers: a float copy of X and an int32_t buffer for Y.
  // NOTE(review): usage is inferred from the names only -- confirm.
162  std::vector<float> X_dequantized_;
163  std::vector<int32_t> Y_int32_;
164 
  // Starts at 0.0f.
  // NOTE(review): presumably remembers the input scale last seen so cached
  // quantized parameters can be refreshed when it changes -- confirm.
165  float cached_X_qparams_scale_ = 0.0f;
166 
167  // Input: X, gamma, beta
168  // Output: Y, mu, inv_sig
169  INPUT_TAGS(INPUT, GAMMA, BETA);
170  OUTPUT_TAGS(OUTPUT, MU, INV_SIGMA);
171 };
172 
// Free-function kernels declared alongside the operator. Only declarations
// are visible here; the definitions live elsewhere.
// NOTE(review): going by the AVX2 suffix these are presumably defined in a
// translation unit compiled with AVX2 enabled -- confirm in the build files.
173 namespace internal {
174 
  // Reads N elements' worth of src (const T*) and writes through sum and sumsq.
  // NOTE(review): presumably accumulates the sum and the sum of squares into
  // 64-bit integers -- confirm in the definition.
175 template <typename T>
176 void VectorMomentsAVX2(const int N, const T* src, int64_t* sum, int64_t* sumsq);
177 
  // Free-function relative of the member ComputeQuantizedFusedParams. The
  // parameter list differs: it takes X_zero_point and has no beta argument.
  // NOTE(review): bias is a non-const pointer and may be read as well as
  // written (e.g. pre-filled with beta) -- confirm in the definition.
178 void ComputeQuantizedFusedParamsAVX2(
179  const int N,
180  const int G,
181  const int K,
182  const int32_t X_zero_point,
183  const int32_t* mu,
184  const int32_t* rsig,
185  const int32_t* gamma,
186  int32_t* scale,
187  int32_t* bias);
188 
  // Takes requantization parameters by const reference in addition to the
  // X / scale / bias inputs and the Y output (T*).
  // NOTE(review): by its name this applies the per-(batch, channel) affine
  // transform and requantizes the result in one pass -- confirm.
189 template <typename T>
190 void AffineBatchChannelAndRequantizeNCHWAVX2(
191  const int N,
192  const int C,
193  const int HxW,
194  const dnnlowp::RequantizationParams& params,
195  const T* X,
196  const int32_t* scale,
197  const int32_t* bias,
198  T* Y);
199 
  // NHWC counterpart of AffineBatchChannelAndRequantizeNCHWAVX2 (identical
  // parameter list).
200 template <typename T>
201 void AffineBatchChannelAndRequantizeNHWCAVX2(
202  const int N,
203  const int C,
204  const int HxW,
205  const dnnlowp::RequantizationParams& params,
206  const T* X,
207  const int32_t* scale,
208  const int32_t* bias,
209  T* Y);
210 
211 } // namespace internal
212 
213 } // namespace caffe2
The CPU Context, representing the bare minimum of what a Context class in Caffe2 should implement...
Definition: context.h:40
Workspace is a class that holds all the related objects created during runtime: (1) all blobs...
Definition: workspace.h:47
A global dictionary that holds information about what Caffe2 modules have been loaded in the current ...
Definition: blob.h:13
Definition: static.cpp:64
A convenient base class for C2 operators with DNNLOWP engine.
Definition: dnnlowp_op.h:77