Caffe2 - C++ API
A deep learning, cross-platform ML framework
conv_dnnlowp_op.cc
1 #include "conv_dnnlowp_op.h"
2 
3 // #define DNNLOWP_MEASURE_TIME_BREAKDOWN
4 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
5 #include <chrono>
6 #endif
7 
8 #ifdef _OPENMP
9 #include <omp.h>
10 #endif
11 
12 #include "caffe2/core/tensor_int8.h"
13 #include "caffe2/utils/cpuid.h"
14 
15 #include <fbgemm/src/RefImplementations.h>
16 
17 #include "dnnlowp_op.h"
18 #include "dnnlowp_partition.h"
19 #include "fbgemm_pack_op.h"
20 #include "im2col_dnnlowp.h"
21 #include "mmio.h"
22 
23 C10_DEFINE_bool(
24  caffe2_dnnlowp_shared_int32_buffer,
25  false,
26  "Share intermediate int32 buffer across DNNLOWP Conv ops");
27 
28 C10_DEFINE_bool(
29  caffe2_dnnlowp_dump_tensors,
30  false,
31  "Dump quantized input and weight tensors used in Conv and FC operators "
32  "during the first iteration");
33 
34 C10_DECLARE_bool(caffe2_dnnlowp_force_slow_path);
35 
36 namespace caffe2 {
37 
38 using namespace std;
39 
40 template <typename T, bool ReluFused>
41 ConvDNNLowPOp<T, ReluFused>::ConvDNNLowPOp(
42  const OperatorDef& operator_def,
43  Workspace* ws)
44  : BaseType(operator_def, ws),
45  column_offsets_(make_shared<vector<int32_t>>()),
46  b_quantized_(make_shared<vector<int32_t>>()) {
47  in_qparams_.resize(1);
48 
49  // Create shared buffer mutex in the constructor
50  // to avoid race-condition in DAGNet.
51  if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
52  createSharedBuffer<CPUContext>(ws_);
53  }
54 
55  if (FLAGS_caffe2_dnnlowp_shared_int32_buffer) {
56  this->CreateSharedInt32Buffer_();
57  }
58 
59  quantize_groupwise_ =
60  this->template GetSingleArgument<bool>("quantize_groupwise", false);
61 }
62 
63 template <typename T, bool ReluFused>
64 ConvDNNLowPOp<T, ReluFused>::~ConvDNNLowPOp() {}
65 
66 template <typename T, bool ReluFused>
67 dnnlowp::TensorQuantizationParams&
68 ConvDNNLowPOp<T, ReluFused>::FilterQuantizationParams(int group_id) {
69  return filter_qparams_[quantize_groupwise_ ? group_id : 0];
70 }
71 
72 template <typename T, bool ReluFused>
73 dnnlowp::RequantizationParams&
74 ConvDNNLowPOp<T, ReluFused>::RequantizationParams(int group_id) {
75  return requantization_params_[quantize_groupwise_ ? group_id : 0];
76 }
77 
78 // FIXME : code duplication with
79 // ConvDNNLowPPackWeightOp::TakeDepthWise3x3FastPath_
80 template <typename T, bool ReluFused>
81 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3FastPath_() {
82  const Tensor& X = InputTensorCPU_(INPUT);
83  return this->order_ == StorageOrder::NHWC && is_same<T, uint8_t>::value &&
84  !Acc16() && group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
85  this->kernel_.size() == 2 && kernel_h() == 3 && kernel_w() == 3 &&
86  stride_h() == stride_w() && (stride_h() == 1 || stride_h() == 2) &&
87  dilation_h() == 1 && dilation_w() == 1 && pad_t() == 1 && pad_b() == 1 &&
88  pad_l() == 1 && pad_r() == 1 && GetCpuId().avx2();
89 }
90 
91 // FIXME : code duplication with
92 // ConvDNNLowPPackWeightOp::TakeDepthWise3x3x3FastPath_
93 template <typename T, bool ReluFused>
94 bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
95  const Tensor& X = InputTensorCPU_(INPUT);
96  bool ret = this->order_ == StorageOrder::NHWC && is_same<T, uint8_t>::value &&
97  !Acc16() && group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
98  this->kernel_.size() == 3 && this->kernel_[0] == 3 &&
99  this->kernel_[1] == 3 && this->kernel_[2] == 3 &&
100  this->stride_[0] == this->stride_[1] &&
101  this->stride_[0] == this->stride_[2] &&
102  (this->stride_[0] == 1 || this->stride_[0] == 2) &&
103  this->dilation_[0] == 1 && this->dilation_[1] == 1 &&
104  this->dilation_[2] == 1 &&
105  accumulate(
106  this->pads_.begin(), this->pads_.end(), 1, multiplies<int>()) == 1 &&
107  GetCpuId().avx2();
108  return ret;
109 }
110 
111 template <typename T, bool ReluFused>
112 bool ConvDNNLowPOp<T, ReluFused>::TakeGConvFastPath_() {
113  const Tensor& X = InputTensorCPU_(INPUT);
114  if (this->order_ != StorageOrder::NHWC || !is_same<T, uint8_t>::value ||
115  !X.template IsType<T>() || this->kernel_.size() != 2) {
116  return false;
117  }
118 
119  auto& filter = InputTensorCPU_(FILTER);
120  const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
121  const int M = filter.dim32(0);
122  fbgemm::conv_param_t<> conv_p(
123  N,
124  C,
125  M,
126  {X.dim32(1), X.dim32(2)},
127  group_,
128  {this->kernel_[0], this->kernel_[1]},
129  {this->stride_[0], this->stride_[1]},
130  {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
131 
132  return fbgemm::fbgemmOptimizedGConv(conv_p);
133 }
134 
135 template <typename T, bool ReluFused>
136 int ConvDNNLowPOp<T, ReluFused>::KernelDim_() {
137  int kernel_dim;
138  const Tensor& X = InputTensorCPU_(INPUT);
139  const auto& filter = InputTensorCPU_(FILTER);
140 
141  int C;
142  int filter_offset;
143  if (ConvPoolOpBase<CPUContext>::order_ == StorageOrder::NCHW) {
144  C = X.dim32(1);
145  filter_offset = 2;
146  } else {
147  C = X.dim32(X.dim() - 1);
148  filter_offset = 1;
149  }
150 
151  int kernel_dims_size = 1;
152  for (int i = 0; i < this->kernel_.size(); ++i) {
153  CAFFE_ENFORCE_EQ(filter.dim32(i + filter_offset), kernel_[i]);
154  kernel_dims_size *= kernel_[i];
155  }
156  kernel_dim = C / group_ * kernel_dims_size;
157 
158  return kernel_dim;
159 }
160 
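// IsConvGEMM_ below returns true when the convolution degenerates into a plain
// GEMM: every kernel size, stride, and dilation is 1 and every pad is 0, so the
// im2col step can be skipped entirely.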
161 template <typename T, bool ReluFused>
162 bool ConvDNNLowPOp<T, ReluFused>::IsConvGEMM_() {
163  return accumulate(
164  this->kernel_.begin(),
165  this->kernel_.end(),
166  1,
167  multiplies<int>()) == 1 &&
168  accumulate(
169  this->stride_.begin(), this->stride_.end(), 1, multiplies<int>()) ==
170  1 &&
171  accumulate(
172  this->dilation_.begin(),
173  this->dilation_.end(),
174  1,
175  multiplies<int>()) == 1 &&
176  accumulate(this->pads_.begin(), this->pads_.end(), 0) == 0;
177 }
178 
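// NoIm2ColNHWC_ below returns true when no explicit im2col buffer is needed in
// NHWC: either a specialized fbgemm path (depthwise 3x3/3x3x3 or optimized grouped
// conv) handles the layout itself, the packed-weight path can fuse im2col into
// activation packing (which requires dilation 1), or the conv is a plain GEMM.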
179 template <typename T, bool ReluFused>
180 bool ConvDNNLowPOp<T, ReluFused>::NoIm2ColNHWC_() {
181  if (TakeDepthWise3x3FastPath_() || TakeDepthWise3x3x3FastPath_() ||
182  TakeGConvFastPath_()) {
183  return true;
184  }
185 
186  if (Wq_packed_ &&
187  accumulate(
188  this->dilation_.begin(),
189  this->dilation_.end(),
190  1,
191  multiplies<int>()) == 1) {
192  // im2col fusion
193  return true;
194  }
195 
196  return IsConvGEMM_();
197 }
198 
199 template <typename T, bool ReluFused>
200 void ConvDNNLowPOp<T, ReluFused>::PreComputeRowColumnOffsets_() {
201  const auto& filter = InputTensorCPU_(FILTER);
202  int kernel_dim = KernelDim_();
203  int M = filter.dim32(0);
204 
205  // Pre-compute row_offset / column_offset
206  vector<int>& offsets =
207  StorageOrder::NCHW == ConvPoolOpBase<CPUContext>::order_
208  ? row_offsets_
209  : *column_offsets_;
210 
211  if (offsets.empty()) {
212  if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
213  const auto& packed_filter =
214  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
215  column_offsets_ = packed_filter.column_offsets;
216  } else {
217  ComputeColumnOffsets<T_signed>(
218  kernel_dim, M, W_quantized_.data(), filter_qparams_, offsets);
219  }
220  }
221 }
222 
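// QuantizeBias_ below quantizes the fp32 bias to int32 with zero_point 0 and scale
// in_scale * filter_scale (per group when quantizing group-wise), so it can be
// added directly to the int32 accumulators before requantization. Pre-quantized
// int32 bias inputs are only accepted if their scale matches that product.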
223 template <typename T, bool ReluFused>
224 void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
225  using namespace dnnlowp;
226 
227  const auto& filter = InputTensorCPU_(FILTER);
228  int M = filter.dim32(0);
229 
230  bool has_packed_bias =
231  this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
232  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER).bias.get();
233  bool has_bias = InputSize() == 3 || has_packed_bias;
234 
235  // Quantize bias
236  if (has_bias &&
237  (!b_quantized_data_ ||
238  in_qparams_[INPUT].scale != in_qparams_scale_old_)) {
239  if (has_packed_bias) {
240  const auto& packed_filter =
241  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
242  b_quantized_ = packed_filter.bias;
243  b_quantized_data_ = b_quantized_->data();
244  } else {
245  const auto& bias = InputTensorCPU_(BIAS);
246  if (this->template InputIsType<int8::Int8TensorCPU>(BIAS)) {
247  TensorQuantizationParams bias_qparams;
248  bias_qparams.scale =
249  this->template Input<int8::Int8TensorCPU>(BIAS).scale;
250  bias_qparams.zero_point =
251  this->template Input<int8::Int8TensorCPU>(BIAS).zero_point;
252  CAFFE_ENFORCE_LE(
253  std::abs(
254  bias_qparams.scale -
255  in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale),
256  1e-4);
257  CAFFE_ENFORCE_EQ(bias_qparams.zero_point, 0);
258  b_quantized_data_ = bias.template data<int32_t>();
259  } else {
260  const float* b_data = bias.template data<float>();
261  b_quantized_->resize(bias.numel());
262  for (int g = 0; g < filter_qparams_.size(); ++g) {
263  int i_begin = g * (M / filter_qparams_.size());
264  int i_end = i_begin + (M / filter_qparams_.size());
265  for (int i = i_begin; i < i_end; ++i) {
266  (*b_quantized_)[i] = fbgemm::Quantize<int32_t>(
267  b_data[i],
268  0,
269  in_qparams_[INPUT].scale * FilterQuantizationParams(g).scale,
270  32,
271  true /* signed */);
272  }
273  }
274  b_quantized_data_ = b_quantized_->data();
275  }
276  in_qparams_scale_old_ = in_qparams_[INPUT].scale;
277  }
278 
279  CAFFE_ENFORCE(b_quantized_data_);
280  }
281 }
282 
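// QuantizeWeight_ below quantizes the weights (or reuses a pre-packed
// Int8ConvDNNLowPPackedWeightBlob) and selects a packed representation:
// Packed3x3ConvMatrix / Packed3x3x3ConvMatrix for the depthwise fast paths,
// PackWeightMatrixForGConv for optimized grouped conv, PackBMatrix for the generic
// fbgemm path, or plain W_quantized_ when falling back to the slow path.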
283 template <typename T, bool ReluFused>
284 void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
285  using namespace dnnlowp;
286 
287  // Quantize W if not done already
288  int kernel_dim = KernelDim_();
289  const auto& filter = InputTensorCPU_(FILTER);
290  int M = filter.dim32(0);
291 
292  bool packW = ConvPoolOpBase<CPUContext>::order_ == StorageOrder::NHWC &&
293  !Acc16() && is_same<T, uint8_t>::value && GetCpuId().avx2() &&
294  !FLAGS_caffe2_dnnlowp_force_slow_path;
295 
296  bool depthwise_3x3_fast_path = false, depthwise_3x3x3_fast_path = false,
297  gconv_fast_path = false;
298  if (TakeDepthWise3x3FastPath_()) {
299  depthwise_3x3_fast_path = true;
300  packW = false;
301  } else if (TakeDepthWise3x3x3FastPath_()) {
302  depthwise_3x3x3_fast_path = true;
303  packW = false;
304  } else if (TakeGConvFastPath_()) {
305  gconv_fast_path = true;
306  packW = false;
307  }
308 
309  if ((depthwise_3x3_fast_path && !Wq_depthwise_3x3_packed_) ||
310  (depthwise_3x3x3_fast_path && !Wq_depthwise_3x3x3_packed_) ||
311  (gconv_fast_path && !Wq_gconv_packed_) || (packW && !Wq_packed_) ||
312  (!packW && W_quantized_.empty())) {
313  if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
314  CAFFE_ENFORCE_EQ(
315  ConvPoolOpBase<CPUContext>::order_,
316  StorageOrder::NHWC,
317  "Pre-packed weight only works with NHWC layout");
318 
319  const auto& packed_filter =
320  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
321  filter_qparams_ = packed_filter.qparams;
322  } else {
323  filter_qparams_.resize(quantize_groupwise_ ? group_ : 1);
324  QuantizeWeight<T>(
325  InputBlob(FILTER),
326  kernel_dim,
327  M,
328  filter_qparams_,
329  W_quantized_,
330  qfactory_.get());
331 
332  if (this->template InputIsType<int8::Int8TensorCPU>(FILTER) &&
333  quantize_groupwise_) {
334  static int log_occurences = 0;
335  if (log_occurences < 32) {
336  ++log_occurences;
337  LOG(WARNING) << "Cannot do group-wise quantization for "
338  "pre-quantized weight "
339  << this->debug_def().input(FILTER);
340  }
341  }
342  }
343 
344  filter_zero_points_.resize(filter_qparams_.size());
345  requantization_params_.resize(filter_qparams_.size());
346  requantization_multipliers_.resize(filter_qparams_.size());
347  for (int i = 0; i < filter_qparams_.size(); ++i) {
348  filter_zero_points_[i] = filter_qparams_[i].zero_point;
349  }
350 
351  if (depthwise_3x3_fast_path) {
352  if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
353  const auto& packed_filter =
354  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
355  Wq_depthwise_3x3_packed_ = packed_filter.W_depthwise_3x3;
356  } else {
357  Wq_depthwise_3x3_packed_.reset(new fbgemm::Packed3x3ConvMatrix(
358  group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
359  }
360  } else if (depthwise_3x3x3_fast_path) {
361  if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
362  const auto& packed_filter =
363  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
364  Wq_depthwise_3x3x3_packed_ = packed_filter.W_depthwise_3x3x3;
365  } else {
366  Wq_depthwise_3x3x3_packed_.reset(new fbgemm::Packed3x3x3ConvMatrix(
367  group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
368  }
369  } else if (gconv_fast_path) {
370  if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
371  const auto& packed_filter =
372  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
373  Wq_gconv_packed_ = packed_filter.W_gconv;
374  } else {
375  const Tensor& X = InputTensorCPU_(INPUT);
376  const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
377 
378  fbgemm::conv_param_t<> conv_p(
379  N,
380  C,
381  M,
382  {X.dim32(1), X.dim32(2)},
383  group_,
384  {this->kernel_[0], this->kernel_[1]},
385  {this->stride_[0], this->stride_[1]},
386  {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
387 
388  Wq_gconv_packed_.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
389  fbgemm::matrix_op_t::Transpose,
390  conv_p,
391  reinterpret_cast<const int8_t*>(W_quantized_.data())));
392  }
393  } else if (packW) {
394  if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
395  const auto& packed_filter =
396  this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
397  Wq_packed_ = packed_filter.W;
398  } else {
399  // fast path using fbgemm
400  Wq_packed_.reset(new fbgemm::PackBMatrix<int8_t>(
401  fbgemm::matrix_op_t::Transpose,
402  group_ * kernel_dim,
403  M / group_,
404  reinterpret_cast<const int8_t*>(W_quantized_.data()),
405  kernel_dim, // ld
406  nullptr, // pmat
407  group_));
408  }
409  } else {
410  string reason;
411  if (ConvPoolOpBase<CPUContext>::order_ != StorageOrder::NHWC) {
412  reason = "fbgemm only supports NHWC layout";
413  } else if (!is_same<T, uint8_t>::value) {
414  reason = "fbgemm only supports 8-bit integers";
415  } else if (!GetCpuId().avx2()) {
416  reason = "fbgemm only supports AVX2+";
417  } else if (Acc16()) {
418  reason = "";
419  } else if (FLAGS_caffe2_dnnlowp_force_slow_path) {
420  reason = "slow path enforced";
421  } else {
422  assert(false);
423  }
424  if (!reason.empty()) {
425  static int log_occurences = 0;
426  if (log_occurences < 32) {
427  ++log_occurences;
428  LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER)
429  << " falls back to slow path because " << reason;
430  }
431  }
432  }
433  }
434 }
435 
439 template <typename T, bool ReluFused>
440 bool ConvDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() {
441  using namespace dnnlowp;
442 
443  if (!this->arguments_parsed_) {
444  bool dequantize_output;
445  ParseDNNLowPOperatorArguments(
446  this, &dequantize_output, &measure_quantization_error_, &followed_by_);
447  CAFFE_ENFORCE_EQ(
448  dequantize_output,
449  false,
450  "Conv DNNLOWP operators don't support dequantize_output");
451 
452  if (ReluFused) {
453  // Relu is fused into this op rather than being a separate op that follows it,
454  // but we still set followed_by_ so that quantization error is measured
455  // correctly in this->MeasureQuantizationError_.
456  followed_by_ = "Relu";
457  AdjustOutputTensorQuantizationParamsWithFollowedBy(this, followed_by_);
458  }
459  this->arguments_parsed_ = true;
460  }
461 
462  // Choose quantization for X
463  in_qparams_[INPUT] =
464  GetInputTensorQuantizationParamsOf(this, INPUT, qfactory_.get());
465 
466  QuantizeWeight_();
467  PreComputeRowColumnOffsets_();
468  if (Wq_packed_ && !FLAGS_caffe2_dnnlowp_dump_tensors) {
469  // From here, W_quantized_ is not used anymore when we have Wq_packed_
470  vector<T_signed>().swap(W_quantized_);
471  }
472 
473  QuantizeBias_();
474 
475  bool fp32_executed = false;
476  if (HasStaticQuantization(this)) {
477  out_qparams_ = GetStaticQuantizationParamsOf(this, 0);
478  } else {
479  // If quantization parameters are not chosen beforehand, run reference
480  // Conv op in fp32 to choose quantization for Y.
481  Fp32Op_()->DequantizeInput();
482  Fp32Op_()->Get()->RunOnDevice();
483  out_qparams_ = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get());
484  fp32_executed = true;
485  }
486 
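 // The int32 accumulators carry an effective scale of in_scale * filter_scale, so
 // real_multiplier = in_scale * filter_scale / out_scale is what maps them onto
 // the output quantization grid; ChooseRequantizationMultiplier converts it into
 // the requantization parameters used later by fbgemm::Requantize.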
487  for (int g = 0; g < filter_qparams_.size(); ++g) {
488  float real_multiplier = in_qparams_[INPUT].scale *
489  FilterQuantizationParams(g).scale / out_qparams_.scale;
490  requantization_params_[g] = qfactory_->ChooseRequantizationMultiplier(
491  real_multiplier, out_qparams_);
492  requantization_multipliers_[g] = requantization_params_[g].real_multiplier;
493  }
494 
495  if (measure_quantization_error_ && Fp32Op_() && !fp32_executed) {
496  // to measure quantization error, run ref impl.
497  Fp32Op_()->DequantizeInput();
498  Fp32Op_()->Get()->RunOnDevice();
499  }
500 
501  return true;
502 }
503 
504 template <typename T, bool ReluFused>
505 void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNCHW_(
506  const T* col_buffer_data,
507  int32_t* Y_int32,
508  T* Y_data,
509  size_t i_offset,
510  int group_id) {
511  auto& filter = InputTensorCPU_(FILTER);
512  const int M = filter.dim32(0);
513  int kernel_dim = KernelDim_();
514 
515  Tensor* Y = OutputTensorCPU_(0);
516  const int Y_HxW = this->GetDimsSize(*Y);
517 
518  // See batch_matmul_dnnlowp_op.cc for why we compute column_offsets,
519  // row_offset, and const_offset in this way.
520  int tid = dnnlowp_get_thread_num();
521  int32_t* column_offsets = column_offsets_->data() + tid * Y_HxW;
522 
523  const dnnlowp::TensorQuantizationParams& filter_qparams =
524  FilterQuantizationParams(group_id);
525  for (int j = 0; j < Y_HxW; ++j) {
526  int sum = 0;
527  for (int k = 0; k < kernel_dim; ++k) {
528  sum += col_buffer_data[k * Y_HxW + j];
529  }
530  column_offsets[j] = sum * filter_qparams.zero_point;
531  }
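 // column_offsets[j] is filter_zero_point times the sum of the j-th activation
 // column; row_offset below applies -input_zero_point to the precomputed
 // per-output-channel weight offsets and adds the quantized bias. Subtracting
 // these corrections from the raw dot products cancels the zero-point cross terms
 // before requantization.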
532 
533  for (int i = 0; i < M / group_; ++i) {
534  int32_t row_offset = row_offsets_[i_offset + i];
535  row_offset *= -in_qparams_[INPUT].zero_point;
536  if (b_quantized_data_) {
537  row_offset += b_quantized_data_[i_offset + i];
538  }
539  for (int j = 0; j < Y_HxW; ++j) {
540  int32_t raw = Y_int32[i * Y_HxW + j] + row_offset - column_offsets[j];
541  if (ReluFused) {
542  raw = std::max(0, raw);
543  }
544  Y_data[i * Y_HxW + j] =
545  fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
546  }
547  }
548 }
549 
550 template <typename T, bool ReluFused>
551 bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
552  VLOG(2) << "Running DNNLOWP Conv";
553 
554  using namespace dnnlowp;
555 
556  // Get quantization parameters
557  if (!GetQuantizationParameters_()) {
558  return false;
559  }
560 
561  const Tensor& X = InputTensorCPU_(INPUT);
562  auto& filter = InputTensorCPU_(FILTER);
563  const int N = X.dim32(0), C = X.dim32(1);
564  CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
565  const int M = filter.dim32(0);
566  CAFFE_ENFORCE_EQ(
567  C,
568  filter.dim32(1) * group_,
569  "Convolution op: input channels does not match: # of input channels ",
570  C,
571  " is not equal to kernel channels * group:",
572  filter.dim32(1),
573  "*",
574  group_);
575  CAFFE_ENFORCE_EQ(
576  M % group_,
577  0,
578  "The number of output channels is not divisible by group.");
579 
580  auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
581  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
582 
583  const vector<int> input_dims = GetDims(X);
584  const vector<int> output_dims = GetDims(*Y);
585  const int X_HxW = this->GetDimsSize(X);
586  const int Y_HxW = this->GetDimsSize(*Y);
587 
588  // The dimension of each kernel
589  const int kernel_dim = KernelDim_();
590 
591  vector<int> img_shape;
592  img_shape.assign(X.sizes().begin() + 1, X.sizes().end());
593 
594  vector<int> buffer_shape;
595  buffer_shape.push_back(kernel_dim);
596  buffer_shape.insert(
597  buffer_shape.end(), output_dims.begin(), output_dims.end());
598  buffer_shape.insert(buffer_shape.begin(), dnnlowp_get_max_threads());
599 
600  if (this->kernel_.size() != 2) {
601  SetDeviceTensor(img_shape, &img_shape_device_);
602  SetDeviceTensor(buffer_shape, &col_buffer_shape_device_);
603  }
604 
605  const int col_buffer_size = kernel_dim * Y_HxW;
606 
607  // The offset corresponding to a single input image, and a single output
608  // image.
609  const int input_offset = C / group_ * X_HxW;
610 
611  // The col buffer is stored in CHW order as well - kernel_dim, and the
612  // height and width.
613  const T* Xdata = X.template data<T>();
614 
615  // We must not call mutable_data inside omp region
616  T* Y_data_T = Y->template mutable_data<T>();
617  column_offsets_->resize(Y_HxW * dnnlowp_get_max_threads());
618 
619  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
620  col_buffer->Resize(buffer_shape);
621  vector<int> buffer_shape_per_thread(
622  buffer_shape.begin() + 1, buffer_shape.end());
623  T* col_buffer_data = col_buffer->template mutable_data<T>();
624 
625  Y_int32->resize(M * Y_HxW * dnnlowp_get_max_threads());
626 
627  // Im2Col, followed by gemm.
628 #ifdef _OPENMP
629 #pragma omp parallel for
630 #endif
631  for (int image_id = 0; image_id < N; ++image_id) {
632  int tid = dnnlowp_get_thread_num();
633  for (int group_id = 0; group_id < group_; ++group_id) {
634  if (this->kernel_.size() == 2) {
635  math::Im2ColNCHW<T>(
636  C / group_,
637  input_dims[0],
638  input_dims[1],
639  kernel_h(),
640  kernel_w(),
641  dilation_h(),
642  dilation_w(),
643  pad_t(),
644  pad_l(),
645  pad_b(),
646  pad_r(),
647  stride_h(),
648  stride_w(),
649  Xdata + (group_ * image_id + group_id) * input_offset,
650  col_buffer_data + tid * col_buffer_size,
651  &context_,
652  in_qparams_[INPUT].zero_point);
653  } else {
654  math::Im2ColNdNCHW<T>(
655  this->kernel_.size(),
656  C * X_HxW,
657  col_buffer_size,
658  img_shape.data(),
659  buffer_shape_per_thread.data(),
660  this->kernel_.data(),
661  this->stride_.data(),
662  this->dilation_.data(),
663  this->pads_.data(),
664  Xdata + (group_ * image_id + group_id) * input_offset,
665  col_buffer_data + tid * col_buffer_size,
666  &context_,
667  in_qparams_[INPUT].zero_point);
668  }
669 
670  // quantize col_buffer
671  T* col_buffer_private = col_buffer_data + tid * col_buffer_size;
672 
673  int32_t* Y_int32_temp =
674  Y_int32->data() + ((M / group_) * group_id + M * tid) * Y_HxW;
675  T_signed* W_quantized_group =
676  W_quantized_.data() + (M / group_) * group_id * kernel_dim;
677 
678  for (int i = 0; i < M / group_; ++i) {
679  for (int j = 0; j < Y_HxW; ++j) {
680  int32_t sum = 0;
681  for (int k = 0; k < kernel_dim; ++k) {
682  int w = W_quantized_group[i * kernel_dim + k];
683  int x = col_buffer_private[k * Y_HxW + j];
684  sum += w * x;
685  }
686  Y_int32_temp[i * Y_HxW + j] = sum;
687  } // j
688  } // i
689 
690  RunOnDeviceEpilogueNCHW_(
691  col_buffer_private,
692  Y_int32_temp,
693  Y_data_T + (M * image_id + M / group_ * group_id) * Y_HxW,
694  M / group_ * group_id,
695  group_id);
696  } // for each group
697  } // for each image_id
698 
699  PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
700  MeasureQuantizationError_();
701  }; // f
702 
703  this->RunWithSharedBuffer_(&col_buffer_, &Y_int32_, f);
704 
705  return true;
706 } // RunOnDeviceWithOrderNCHW
707 
708 template <typename T, bool ReluFused>
709 void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
710  const T* col_buffer_data,
711  int32_t* Y_int32) {
712  const Tensor& X = InputTensorCPU_(INPUT);
713  auto& filter = InputTensorCPU_(FILTER);
714  Tensor* Y = OutputTensorCPU_(0);
715  const int N = X.dim32(0);
716  const int M = filter.dim32(0);
717  int kernel_dim = KernelDim_();
718  const int Y_HxW = this->GetDimsSize(*Y);
719 
720  // Adjust with bias and zero_point and then requantize
721  // See batch_matmul_dnnlowp_op.cc for why we compute column_offsets,
722  // row_offset, and const_offset in this way.
723  int32_t A_zero_point = in_qparams_[INPUT].zero_point;
724 
725  if (!dnnlowp::HasStaticQuantization(this)) {
726  if (quantize_groupwise_) {
727  static int log_occurences = 0;
728  if (log_occurences < 32) {
729  ++log_occurences;
730  LOG(WARNING) << "Cannot do group-wise quantization without "
731  "static quantization of activations for "
732  << this->debug_def().output(0);
733  }
734  }
735 
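 // Dynamic quantization: when the output qparams are not given statically, scan
 // the zero-point- and bias-corrected int32 accumulators to find the output range,
 // then derive out_qparams_ and the requantization multiplier from that range.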
736  int32_t Y_min = numeric_limits<int32_t>::max();
737  int32_t Y_max = numeric_limits<int32_t>::min();
738 
739 #ifdef _OPENMP
740 #pragma omp parallel for reduction(min : Y_min), reduction(max : Y_max)
741 #endif
742  for (int i = 0; i < N * Y_HxW; ++i) {
743  for (int group_id = 0; group_id < group_; ++group_id) {
744  int32_t row_offset = 0;
745  for (int k = 0; k < kernel_dim; ++k) {
746  row_offset +=
747  col_buffer_data[(i * group_ + group_id) * kernel_dim + k];
748  }
749  row_offset *= FilterQuantizationParams(0).zero_point;
750 
751  for (int j = group_id * (M / group_); j < (group_id + 1) * (M / group_);
752  ++j) {
753  int32_t raw = Y_int32[i * M + j] -
754  A_zero_point * (*column_offsets_)[j] - row_offset;
755  if (b_quantized_data_) {
756  raw += b_quantized_data_[j];
757  }
758  Y_min = std::min(Y_min, raw);
759  Y_max = std::max(Y_max, raw);
760  }
761  } // for each group
762  } // for each row i
763 
764  if (ReluFused) {
765  Y_min = std::max(0, Y_min);
766  Y_max = std::max(0, Y_max);
767  }
768 
769  float Y_scale =
770  in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale;
771  out_qparams_ =
772  qfactory_->ChooseQuantizationParams(Y_scale * Y_min, Y_scale * Y_max);
773 
774  float real_multiplier = Y_scale / out_qparams_.scale;
775  requantization_params_[0] = qfactory_->ChooseRequantizationMultiplier(
776  real_multiplier, out_qparams_);
777  requantization_multipliers_[0] = requantization_params_[0].real_multiplier;
778  }
779 
780  int32_t C_zero_point = out_qparams_.zero_point;
781 
782  T* Ydata = Y->template mutable_data<T>();
783 
784  using namespace fbgemm;
785  if (is_same<T, uint8_t>::value && GetCpuId().avx2()) {
786 #ifdef _OPENMP
787 #pragma omp parallel for
788 #endif
789  for (int i = 0; i < N * Y_HxW; ++i) {
790  for (int group_id = 0; group_id < group_; ++group_id) {
791  int32_t row_offset;
792  row_offsets_u8acc32_ref(
793  1,
794  kernel_dim,
795  group_ * kernel_dim,
796  reinterpret_cast<const uint8_t*>(
797  col_buffer_data + (i * group_ + group_id) * kernel_dim),
798  &row_offset);
799 
800  int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
801  float C_multiplier = RequantizationParams(group_id).real_multiplier;
802 
803  requantize_u8acc32_ref(
804  1,
805  M / group_,
806  M,
807  Y_int32 + i * M + group_id * (M / group_),
808  reinterpret_cast<uint8_t*>(Ydata + i * M + group_id * (M / group_)),
809  &C_multiplier,
810  C_zero_point,
811  A_zero_point,
812  &B_zero_point,
813  &row_offset,
814  column_offsets_->data() + group_id * (M / group_),
815  b_quantized_data_ ? b_quantized_data_ + group_id * (M / group_)
816  : nullptr,
817  M / group_,
818  ReluFused);
819  } // for each group
820  } // for each row i
821  } else {
822 #ifdef _OPENMP
823 #pragma omp parallel for
824 #endif
825  for (int i = 0; i < N * Y_HxW; ++i) {
826  for (int group_id = 0; group_id < group_; ++group_id) {
827  int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
828  int32_t row_offset = 0;
829  for (int k = 0; k < kernel_dim; ++k) {
830  row_offset +=
831  col_buffer_data[(i * group_ + group_id) * kernel_dim + k];
832  }
833  row_offset *= B_zero_point;
834 
835  for (int j = group_id * (M / group_); j < (group_id + 1) * (M / group_);
836  ++j) {
837  int32_t raw = Y_int32[i * M + j] -
838  A_zero_point * (*column_offsets_)[j] - row_offset;
839  if (b_quantized_data_) {
840  raw += b_quantized_data_[j];
841  }
842 
843  Ydata[i * M + j] =
844  fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
845  if (ReluFused) { // static if
846  Ydata[i * M + j] =
847  std::max<int32_t>(C_zero_point, Ydata[i * M + j]);
848  }
849  }
850  } // for each group
851  } // for each row i
852  } // !__AVX2__
853 
854  dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
855 }
856 
857 template <typename T, bool ReluFused>
858 void ConvDNNLowPOp<T, ReluFused>::PartitionGroupedNHWCConv_(
859  int* group_begin,
860  int* group_end,
861  int* i_begin,
862  int* i_end,
863  int num_groups,
864  int m,
865  int nthreads,
866  int thread_id) {
867  // Make sure i_per_thread is a multiple of 32 because
868  // cblas_gemm_compute_u8s8s32_acc16 performs the best when M is a multiple
869  // of 32.
870  Get1DPartitionOf2D(
871  num_groups,
872  m,
873  nthreads,
874  thread_id,
875  group_begin,
876  group_end,
877  i_begin,
878  i_end,
879  32);
880 }
881 
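// Im2ColNHWC_ below materializes the NHWC im2col buffer, one
// [Y_H, Y_W, group_ * kernel_dim] slab per image, parallelized over images. The
// input zero point is passed down so that padded regions presumably get the
// quantized representation of zero rather than a literal 0.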
882 template <typename T, bool ReluFused>
883 const T* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
884  const Tensor& X = InputTensorCPU_(INPUT);
885  Tensor* Y = OutputTensorCPU_(0);
886  int ndim = X.dim();
887  const int N = X.dim32(0), C = X.dim32(ndim - 1);
888 
889  const int kernel_dim = KernelDim_();
890  // The offset corresponding to a single input image, and a single output
891  // image.
892  const int X_HxW = this->GetDimsSize(X);
893  const int input_offset = X_HxW * C;
894  const int Y_HxW = this->GetDimsSize(*Y);
895 
896  const T* Xdata = X.template data<T>();
897 
898  vector<int> buffer_shape(ndim);
899  for (auto i = 0; i < ndim - 1; ++i) {
900  buffer_shape[i] = Y->dim32(i);
901  }
902  buffer_shape[ndim - 1] = kernel_dim * group_;
903 
904  col_buffer->Resize(buffer_shape);
905 
906  T* col_buffer_data = col_buffer->template mutable_data<T>();
907 
908 #ifdef _OPENMP
909 #pragma omp parallel for if (N > 1)
910 #endif
911  for (int image_id = 0; image_id < N; ++image_id) {
912  if (this->kernel_.size() <= 2) {
913  math::Im2ColNHWC<T>(
914  C,
915  X.dim32(1),
916  this->kernel_.size() == 2 ? X.dim32(2) : 1,
917  kernel_h(),
918  this->kernel_.size() == 2 ? kernel_w() : 1,
919  dilation_h(),
920  this->kernel_.size() == 2 ? dilation_w() : 1,
921  pad_t(),
922  this->kernel_.size() == 2 ? pad_l() : 0,
923  this->kernel_.size() == 2 ? pad_b() : pad_l(),
924  this->kernel_.size() == 2 ? pad_r() : 0,
925  stride_h(),
926  this->kernel_.size() == 2 ? stride_w() : 1,
927  Xdata + image_id * input_offset,
928  col_buffer_data + image_id * group_ * kernel_dim * Y_HxW,
929  &context_,
930  group_,
931  in_qparams_[INPUT].zero_point);
932  } else {
933  math::Im2Col3DNHWC<T>(
934  C,
935  X.dim32(1), // num_frames
936  X.dim32(2), // H
937  X.dim32(3), // W
938  this->kernel_[0],
939  this->kernel_[1],
940  this->kernel_[2],
941  this->dilation_[0],
942  this->dilation_[1],
943  this->dilation_[2],
944  this->pads_[0],
945  this->pads_[1],
946  this->pads_[2],
947  this->pads_[3],
948  this->pads_[4],
949  this->pads_[5],
950  this->stride_[0],
951  this->stride_[1],
952  this->stride_[2],
953  Xdata + image_id * input_offset,
954  col_buffer_data + image_id * group_ * kernel_dim * Y_HxW,
955  &context_,
956  group_,
957  in_qparams_[INPUT].zero_point);
958  }
959  }
960 
961  return col_buffer->template data<T>();
962 }
963 
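// conv_nhwc_ref_ is the slow reference kernel: for a single group it computes the
// raw int32 accumulators as an [i_begin, i_end) x (M / num_groups) GEMM over the
// NHWC im2col buffer. Zero-point corrections and requantization happen afterwards
// in RunOnDeviceEpilogueNHWC_.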
964 template <typename T, typename T_signed>
965 static void conv_nhwc_ref_(
966  int group_id,
967  int num_groups,
968  int i_begin,
969  int i_end,
970  int M,
971  int kernel_dim,
972  const T* col_buffer,
973  const T_signed* W,
974  int32_t* Y) {
975  for (int i = i_begin; i < i_end; ++i) {
976  for (int j = group_id * (M / num_groups);
977  j < (group_id + 1) * (M / num_groups);
978  ++j) {
979  int32_t sum = 0;
980  for (int k = 0; k < kernel_dim; ++k) {
981  int w = W[j * kernel_dim + k];
982  int x = col_buffer[(i * num_groups + group_id) * kernel_dim + k];
983  sum += w * x;
984  }
985  Y[i * M + j] = sum;
986  }
987  }
988 }
989 
990 template <typename T, bool ReluFused>
991 template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
992 void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM_(
993  PackAMatrix& packA,
994  vector<int32_t>* Y_int32,
995  uint8_t* Y_uint8_data) {
996  // This function is called within an OpenMP region
997  auto& filter = InputTensorCPU_(FILTER);
998  const int M = filter.dim32(0);
999 
1000  int nthreads = dnnlowp_get_num_threads();
1001  int tid = dnnlowp_get_thread_num();
1002 
1003  using namespace fbgemm;
1004  DoNothing<> doNothingObj{};
1005  ReQuantizeOutput<ReluFused, Q_GRAN> outputProcObj(
1006  doNothingObj,
1007  requantization_multipliers_.data(),
1008  out_qparams_.zero_point,
1009  in_qparams_[INPUT].zero_point,
1010  filter_zero_points_.data(),
1011  packA.getRowOffsetBuffer(),
1012  column_offsets_->data(),
1013  b_quantized_data_,
1014  M,
1015  group_);
1016 
1017  fbgemmPacked(
1018  packA,
1019  *Wq_packed_,
1020  Y_uint8_data,
1021  Y_int32->data(),
1022  M,
1023  outputProcObj,
1024  tid,
1025  nthreads);
1026 }
1027 
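// DispatchFBGEMM_ below runs the packed GEMM: packA supplies the packed
// activations and their row-offset buffer, ReQuantizeOutput is fbgemm's output
// pipeline that applies the zero-point corrections, quantized bias, and per-tensor
// or per-group requantization multipliers (optionally fusing Relu) while writing
// uint8 output, and fbgemmPacked executes the GEMM against the prepacked weights
// Wq_packed_.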
1028 template <typename T, bool ReluFused>
1029 void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
1030  const T* col_buffer_data,
1031  vector<int32_t>* Y_int32) {
1032  const Tensor& X = InputTensorCPU_(INPUT);
1033  auto& filter = InputTensorCPU_(FILTER);
1034  Tensor* Y = OutputTensorCPU_(0);
1035  const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
1036  const int M = filter.dim32(0);
1037  const int kernel_dim = KernelDim_();
1038  const int Y_HxW = this->GetDimsSize(*Y);
1039 
1040  if (FLAGS_caffe2_dnnlowp_dump_tensors) {
1041  // Dump input activation
1042  StoreMatrixInMatrixMarketFormat(
1043  N * Y_HxW * group_,
1044  kernel_dim,
1045  col_buffer_data,
1046  this->debug_def().input(INPUT));
1047 
1048  // Dump weight
1049  StoreMatrixInMatrixMarketFormat(
1050  group_ * M,
1051  kernel_dim,
1052  W_quantized_.data(),
1053  this->debug_def().input(FILTER));
1054  }
1055 
1056  using namespace fbgemm;
1057 
1058  if (TakeDepthWise3x3x3FastPath_()) {
1059  const T* Xdata = X.template data<T>();
1060  uint8_t* Y_uint8_data =
1061  OutputTensorCPU_(0)->template mutable_data<uint8_t>();
1062 
1063 #ifdef _OPENMP
1064 #pragma omp parallel
1065 #endif
1066  {
1067  if (quantize_groupwise_) {
1068  depthwise_3x3x3_per_channel_quantization_pad_1(
1069  N,
1070  X.dim32(1),
1071  X.dim32(2),
1072  X.dim32(3),
1073  C,
1074  this->stride_[0],
1075  this->stride_[1],
1076  this->stride_[2],
1077  in_qparams_[INPUT].zero_point,
1078  reinterpret_cast<const uint8_t*>(Xdata),
1079  filter_zero_points_.data(),
1080  *Wq_depthwise_3x3x3_packed_,
1081  requantization_multipliers_.data(),
1082  out_qparams_.zero_point,
1083  Y_uint8_data,
1084  column_offsets_->data(),
1085  b_quantized_data_,
1086  ReluFused,
1087  dnnlowp_get_thread_num(),
1088  dnnlowp_get_num_threads());
1089  } else {
1090  depthwise_3x3x3_pad_1(
1091  N,
1092  X.dim32(1),
1093  X.dim32(2),
1094  X.dim32(3),
1095  C,
1096  this->stride_[0],
1097  this->stride_[1],
1098  this->stride_[2],
1099  in_qparams_[INPUT].zero_point,
1100  reinterpret_cast<const uint8_t*>(Xdata),
1101  FilterQuantizationParams(0).zero_point,
1102  *Wq_depthwise_3x3x3_packed_,
1103  requantization_params_[0].real_multiplier,
1104  out_qparams_.zero_point,
1105  Y_uint8_data,
1106  column_offsets_->data(),
1107  b_quantized_data_,
1108  ReluFused,
1109  dnnlowp_get_thread_num(),
1110  dnnlowp_get_num_threads());
1111  }
1112  } // omp parallel
1113 
1114  return;
1115  } else if (TakeDepthWise3x3FastPath_()) {
1116  const int H = X.dim32(1), W = X.dim32(2);
1117  const T* Xdata = X.template data<T>();
1118  uint8_t* Y_uint8_data =
1119  OutputTensorCPU_(0)->template mutable_data<uint8_t>();
1120 
1121 #ifdef _OPENMP
1122 #pragma omp parallel
1123 #endif
1124  {
1125  if (quantize_groupwise_) {
1126  depthwise_3x3_per_channel_quantization_pad_1(
1127  N,
1128  H,
1129  W,
1130  C,
1131  stride_h(),
1132  stride_w(),
1133  in_qparams_[INPUT].zero_point,
1134  reinterpret_cast<const uint8_t*>(Xdata),
1135  filter_zero_points_.data(),
1136  *Wq_depthwise_3x3_packed_,
1137  requantization_multipliers_.data(),
1138  out_qparams_.zero_point,
1139  Y_uint8_data,
1140  column_offsets_->data(),
1141  b_quantized_data_,
1142  ReluFused,
1143  dnnlowp_get_thread_num(),
1144  dnnlowp_get_num_threads());
1145  } else {
1146  depthwise_3x3_pad_1(
1147  N,
1148  H,
1149  W,
1150  C,
1151  stride_h(),
1152  stride_w(),
1153  in_qparams_[INPUT].zero_point,
1154  reinterpret_cast<const uint8_t*>(Xdata),
1155  FilterQuantizationParams(0).zero_point,
1156  *Wq_depthwise_3x3_packed_,
1157  requantization_params_[0].real_multiplier,
1158  out_qparams_.zero_point,
1159  Y_uint8_data,
1160  column_offsets_->data(),
1161  b_quantized_data_,
1162  ReluFused,
1163  dnnlowp_get_thread_num(),
1164  dnnlowp_get_num_threads());
1165  }
1166  } // omp parallel
1167 
1168  return;
1169  } else if (TakeGConvFastPath_()) {
1170  const T* Xdata = X.template data<T>();
1171  uint8_t* Y_uint8_data =
1172  OutputTensorCPU_(0)->template mutable_data<uint8_t>();
1173 
1174  conv_param_t<> conv_p(
1175  N,
1176  C,
1177  M,
1178  {X.dim32(1), X.dim32(2)},
1179  group_,
1180  {this->kernel_[0], this->kernel_[1]},
1181  {this->stride_[0], this->stride_[1]},
1182  {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});
1183 
1184  int row_offset_size_per_thread = rowOffsetBufferSizeGConv(conv_p);
1185  row_offsets_.resize(dnnlowp_get_max_threads() * row_offset_size_per_thread);
1186 
1187 #ifdef _OPENMP
1188 // TODO: add parallelization once fbgemmGroupwiseConv supports multi-threading
1189 // #pragma omp parallel
1190 #endif
1191  {
1192  int tid = 0; // dnnlowp_get_thread_num();
1193  int nthreads = 1; // dnnlowp_get_num_threads();
1194 
1195  DoNothing<> doNothingObj{};
1196  if (quantize_groupwise_) {
1197  ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj(
1198  doNothingObj,
1199  requantization_multipliers_.data(),
1200  out_qparams_.zero_point,
1201  in_qparams_[INPUT].zero_point,
1202  filter_zero_points_.data(),
1203  row_offsets_.data() + tid * row_offset_size_per_thread,
1204  column_offsets_->data(),
1205  b_quantized_data_,
1206  conv_p.OC,
1207  conv_p.G);
1208 
1209  fbgemmGroupwiseConv(
1210  conv_p,
1211  reinterpret_cast<const uint8_t*>(Xdata),
1212  in_qparams_[INPUT].zero_point,
1213  row_offsets_.data() + tid * row_offset_size_per_thread,
1214  *Wq_gconv_packed_,
1215  Y_uint8_data,
1216  Y_int32->data(),
1217  reqObj,
1218  tid,
1219  nthreads);
1220  } else {
1221  ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
1222  doNothingObj,
1223  requantization_multipliers_.data(),
1224  out_qparams_.zero_point,
1225  in_qparams_[INPUT].zero_point,
1226  filter_zero_points_.data(),
1227  row_offsets_.data() + tid * row_offset_size_per_thread,
1228  column_offsets_->data(),
1229  b_quantized_data_,
1230  conv_p.OC,
1231  conv_p.G);
1232 
1233  fbgemmGroupwiseConv(
1234  conv_p,
1235  reinterpret_cast<const uint8_t*>(Xdata),
1236  in_qparams_[INPUT].zero_point,
1237  row_offsets_.data() + tid * row_offset_size_per_thread,
1238  *Wq_gconv_packed_,
1239  Y_uint8_data,
1240  Y_int32->data(),
1241  reqObj,
1242  tid,
1243  nthreads);
1244  }
1245  } // omp parallel
1246 
1247  return;
1248  }
1249 
1250  // Normal path for non-special (e.g., no depth-wise) convolutions.
1251  int row_offset_size_per_thread = -1;
1252  int x_pack_buf_size_per_thread = -1;
1253  bool fuse_im2col =
1254  Wq_packed_ && X.template data<T>() == col_buffer_data && !IsConvGEMM_();
1255  if (Wq_packed_) {
1256  if (fuse_im2col) {
1257  row_offset_size_per_thread =
1258  PackAWithIm2Col<uint8_t>::rowOffsetBufferSize();
1259  x_pack_buf_size_per_thread = PackAWithIm2Col<uint8_t>::packedBufferSize();
1260  } else {
1261  row_offset_size_per_thread =
1262  PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
1263  x_pack_buf_size_per_thread =
1264  PackAWithRowOffset<uint8_t>::packedBufferSize();
1265  }
1266  row_offsets_.resize(dnnlowp_get_max_threads() * row_offset_size_per_thread);
1267  X_pack_buf_.resize(dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
1268  }
1269 
1270  uint8_t* Y_uint8_data = Y->template mutable_data<uint8_t>();
1271 
1272  if (Wq_packed_)
1273 #ifdef _OPENMP
1274 #pragma omp parallel
1275 #endif
1276  {
1277  int tid = dnnlowp_get_thread_num();
1278 
1279  // fast path to use fbgemm
1280  if (fuse_im2col) {
1281  if (this->kernel_.size() <= 2) {
1282  conv_param_t<> conv_p(
1283  N,
1284  C,
1285  M,
1286  {X.dim32(1), this->kernel_.size() == 2 ? X.dim32(2) : 1},
1287  group_,
1288  {this->kernel_[0],
1289  this->kernel_.size() == 2 ? this->kernel_[1] : 1},
1290  {this->stride_[0],
1291  this->kernel_.size() == 2 ? this->stride_[1] : 1},
1292  {this->pads_[0],
1293  this->kernel_.size() == 2 ? this->pads_[1] : 0,
1294  this->kernel_.size() == 2 ? this->pads_[2] : this->pads_[1],
1295  this->kernel_.size() == 2 ? this->pads_[3] : 0});
1296 
1297  PackAWithIm2Col<uint8_t> packA(
1298  conv_p,
1299  reinterpret_cast<const uint8_t*>(col_buffer_data),
1300  // buffer for packed matrix
1301  X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
1302  in_qparams_[INPUT].zero_point,
1303  row_offsets_.data() + tid * row_offset_size_per_thread);
1304 
1305  if (quantize_groupwise_) {
1306  DispatchFBGEMM_<
1307  PackAWithIm2Col<uint8_t>,
1308  QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
1309  } else {
1310  DispatchFBGEMM_<
1311  PackAWithIm2Col<uint8_t>,
1312  QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
1313  }
1314  } else {
1315  // 3D
1316  conv_param_t<3> conv_p(
1317  N,
1318  C,
1319  M,
1320  {X.dim32(1), X.dim32(2), X.dim32(3)},
1321  group_,
1322  {this->kernel_[0], this->kernel_[1], this->kernel_[2]},
1323  {this->stride_[0], this->stride_[1], this->stride_[2]},
1324  {this->pads_[0],
1325  this->pads_[1],
1326  this->pads_[2],
1327  this->pads_[3],
1328  this->pads_[4],
1329  this->pads_[5]});
1330 
1331  PackAWithIm2Col<uint8_t, int32_t, 3> packA(
1332  conv_p,
1333  reinterpret_cast<const uint8_t*>(col_buffer_data),
1334  // buffer for packed matrix
1335  X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
1336  in_qparams_[INPUT].zero_point,
1337  row_offsets_.data() + tid * row_offset_size_per_thread);
1338 
1339  if (quantize_groupwise_) {
1340  DispatchFBGEMM_<
1341  PackAWithIm2Col<uint8_t, int32_t, 3>,
1342  QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
1343  } else {
1344  DispatchFBGEMM_<
1345  PackAWithIm2Col<uint8_t, int32_t, 3>,
1346  QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
1347  }
1348  } // 3D
1349  } else {
1350  // no im2col fusion
1351  PackAWithRowOffset<uint8_t> packA(
1352  matrix_op_t::NoTranspose,
1353  N * Y_HxW,
1354  group_ * kernel_dim,
1355  reinterpret_cast<const uint8_t*>(col_buffer_data),
1356  group_ * kernel_dim,
1357  // buffer for packed matrix
1358  X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
1359  group_,
1360  row_offsets_.data() + tid * row_offset_size_per_thread);
1361 
1362  if (quantize_groupwise_) {
1363  DispatchFBGEMM_<
1364  PackAWithRowOffset<uint8_t>,
1365  QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
1366  } else {
1367  DispatchFBGEMM_<
1368  PackAWithRowOffset<uint8_t>,
1369  QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
1370  }
1371  } // no im2col fusion
1372  } else {
1373  for (int group_id = 0; group_id < group_; ++group_id) {
1374  // Wq_packed_.empty()
1375  conv_nhwc_ref_(
1376  group_id,
1377  group_,
1378  0,
1379  N * Y_HxW,
1380  M,
1381  kernel_dim,
1382  col_buffer_data,
1383  W_quantized_.data(),
1384  Y_int32->data());
1385  }
1386  }
1387 }
1388 
1389 template <typename T, bool ReluFused>
1390 bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
1391  CAFFE_ENFORCE_LE(
1392  this->kernel_.size(),
1393  3,
1394  "Only 1-3d convolutions are supported for NHWC storage type");
1395 
1396  using namespace dnnlowp;
1397 
1398 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
1399  chrono::time_point<chrono::system_clock> t_very_begin, t_begin, t_end;
1400  /*if (VLOG_IS_ON(3))*/ {
1401  t_begin = chrono::system_clock::now();
1402  t_very_begin = t_begin;
1403  }
1404 #endif
1405 
1406  // Get quantization parameters
1407  if (!GetQuantizationParameters_()) {
1408  return false;
1409  }
1410 
1411 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
1412  /*if (VLOG_IS_ON(3))*/ {
1413  t_end = chrono::system_clock::now();
1414  double dt = chrono::duration<double>(t_end - t_begin).count();
1415  LOG(INFO) << "this=" << this << " get_quant_params: " << dt * 1e3 << " ms";
1416  }
1417 #endif
1418 
1419  const Tensor& X = InputTensorCPU_(INPUT);
1420  auto& filter = InputTensorCPU_(FILTER);
1421  const int C = X.dim32(X.dim() - 1);
1422  const int G = group_;
1423  CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
1424  const int M = filter.dim32(0);
1425  CAFFE_ENFORCE_EQ(
1426  C,
1427  filter.dim32(filter.dim() - 1) * G,
1428  "Convolution op: input channels does not match: # of input channels ",
1429  C,
1430  " is not equal to kernel channels * group: ",
1431  filter.dim32(filter.dim() - 1),
1432  "*",
1433  G);
1434  CAFFE_ENFORCE_EQ(
1435  M % G, 0, "The number of output channels is not divisible by group.");
1436 
1437  auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
1438  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());
1439 
1440  // The col buffer is stored in HWC order as well - kernel_dim, and the height
1441  // and width.
1442 
1443 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
1444  /*if (VLOG_IS_ON(3)) */ { t_begin = chrono::system_clock::now(); }
1445 #endif
1446 
1447  bool no_im2col = NoIm2ColNHWC_();
1448  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
1449  if (!TakeDepthWise3x3FastPath_() && !TakeDepthWise3x3x3FastPath_()) {
1450  Y_int32->resize(Y->numel());
1451  }
1452 
1453  // Im2col, followed by gemm.
1454  const T* Xdata = X.template data<T>();
1455  const T* col_buffer_data = no_im2col ? Xdata : Im2ColNHWC_(col_buffer);
1456 
1457 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
1458  /*if (VLOG_IS_ON(3)) */ {
1459  t_end = chrono::system_clock::now();
1460  double dt = chrono::duration<double>(t_end - t_begin).count();
1461  LOG(INFO) << "this=" << this << " im2col: " << dt * 1e3 << " ms";
1462  t_begin = chrono::system_clock::now();
1463  }
1464 #endif
1465 
1466 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
1467  /*if (VLOG_IS_ON(3)) */ {
1468  t_end = chrono::system_clock::now();
1469  double dt = chrono::duration<double>(t_end - t_begin).count();
1470  LOG(INFO) << "this=" << this << " quantize col_buf: " << dt * 1e3
1471  << " ms";
1472  t_begin = chrono::system_clock::now();
1473  }
1474 #endif
1475 
1476  ConvNHWCCore_(col_buffer_data, Y_int32);
1477 
1478 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
1479  /*if (VLOG_IS_ON(3)) */ {
1480  t_end = chrono::system_clock::now();
1481  double dt = chrono::duration<double>(t_end - t_begin).count();
1482  LOG(INFO) << "this=" << this << " GEMM: " << dt * 1e3 << " ms";
1483  t_begin = chrono::system_clock::now();
1484  }
1485 #endif
1486 
1487  if (Wq_packed_ || Wq_depthwise_3x3_packed_ || Wq_depthwise_3x3x3_packed_ ||
1488  Wq_gconv_packed_) {
1489  // In the fbgemm fast paths, requantization has already been applied,
1490  // so only the output quantization params need to be propagated here.
1491  PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
1492  } else {
1493  RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
1494  }
1495  }; // f
1496 
1497  this->RunWithSharedBuffer_(&col_buffer_, &Y_int32_, f);
1498 
1499 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
1500  /*if (VLOG_IS_ON(3)) */ {
1501  const int N = X.dim32(0);
1502  // The dimension of each kernel
1503  const int kernel_dim = KernelDim_();
1504  // The output image size is the spatial size of the output.
1505  const int Y_HxW = this->GetDimsSize(*Y);
1506 
1507  t_end = chrono::system_clock::now();
1508  double dt = chrono::duration<double>(t_end - t_begin).count();
1509  LOG(INFO) << "this=" << this << " prologue: " << dt * 1e3 << " ms";
1510  t_begin = chrono::system_clock::now();
1511 
1512  t_end = chrono::system_clock::now();
1513  const int M = filter.dim32(0);
1514  double ops = 2. * N * Y_HxW * M * kernel_dim;
1515  dt = chrono::duration<double>(t_end - t_very_begin).count();
1516  double gops = ops / dt / 1e9;
1517  LOG(INFO) << "this=" << this << " " << this->debug_def().type()
1518  << " output=" << this->debug_def().output(0) << " " << N * Y_HxW
1519  << "x" << M << "x" << kernel_dim << " G=" << group_
1520  << " C/G=" << C / group_ << " K/G=" << M / group_
1521  << " R=" << kernel_h() << " S=" << kernel_w() << " : " << dt * 1e3
1522  << " ms " << gops << " gops";
1523  }
1524 #endif
1525 
1526  MeasureQuantizationError_();
1527 
1528  return true;
1529 }
1530 
1531 template class ConvDNNLowPOp<uint8_t, false>;
1532 template class ConvDNNLowPOp<uint8_t, true>;
1533 
1534 template class ConvDNNLowPOp<uint16_t, false>;
1535 template class ConvDNNLowPOp<uint16_t, true>;
1536 
1537 OPERATOR_SCHEMA(ConvRelu).NumInputs(2, 3).NumOutputs(1).TensorInferenceFunction(
1538  ConvPoolOpBase<CPUContext>::TensorInferenceForConv);
1539 
1540 REGISTER_CPU_OPERATOR_WITH_ENGINE(Conv, DNNLOWP, ConvDNNLowPOp<uint8_t, false>);
1541 REGISTER_CPU_OPERATOR_WITH_ENGINE(
1542  ConvRelu,
1543  DNNLOWP,
1544  ConvDNNLowPOp<uint8_t, true>);
1545 
1546 REGISTER_CPU_OPERATOR_WITH_ENGINE(
1547  Int8Conv,
1548  DNNLOWP,
1549  ConvDNNLowPOp<uint8_t, false>);
1550 REGISTER_CPU_OPERATOR_WITH_ENGINE(
1551  Int8ConvRelu,
1552  DNNLOWP,
1553  ConvDNNLowPOp<uint8_t, true>);
1554 
1555 REGISTER_CPU_OPERATOR_WITH_ENGINE(
1556  Conv,
1557  DNNLOWP_16,
1558  ConvDNNLowPOp<uint16_t, false>);
1559 REGISTER_CPU_OPERATOR_WITH_ENGINE(
1560  ConvRelu,
1561  DNNLOWP_16,
1562  ConvDNNLowPOp<uint16_t, true>);
1563 
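// Usage sketch (not part of this file): the registrations above expose these
// kernels under the DNNLOWP / DNNLOWP_16 engines, so a network opts in by setting
// the engine field on an ordinary Conv/ConvRelu/Int8Conv operator rather than by
// using a different operator type. For example, roughly:
//
//   caffe2::OperatorDef def;
//   def.set_type("ConvRelu");      // or "Conv", "Int8Conv", "Int8ConvRelu"
//   def.set_engine("DNNLOWP");     // selects ConvDNNLowPOp<uint8_t, ...>
//   def.add_input("X_q"); def.add_input("W_q"); def.add_input("b");
//   def.add_output("Y_q");
//
// Input/output names here are placeholders; the usual Conv arguments (kernel,
// stride, pad, order, ...) still apply.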
1564 } // namespace caffe2