1 #include "conv_dnnlowp_op.h" 4 #ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN 12 #include "caffe2/core/tensor_int8.h" 13 #include "caffe2/utils/cpuid.h" 15 #include <fbgemm/src/RefImplementations.h> 17 #include "dnnlowp_op.h" 18 #include "dnnlowp_partition.h" 19 #include "fbgemm_pack_op.h" 20 #include "im2col_dnnlowp.h" 24 caffe2_dnnlowp_shared_int32_buffer,
26 "Share intermediate int32 buffer across DNNLOWP Conv ops");
29 caffe2_dnnlowp_dump_tensors,
31 "Dump quantized input and weight tensors used in Conv and FC operators " 32 "during the first iteration");
34 C10_DECLARE_bool(caffe2_dnnlowp_force_slow_path);
template <typename T, bool ReluFused>
ConvDNNLowPOp<T, ReluFused>::ConvDNNLowPOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : BaseType(operator_def, ws),
      column_offsets_(make_shared<vector<int32_t>>()),
      b_quantized_(make_shared<vector<int32_t>>()) {
  in_qparams_.resize(1);

  // Create shared buffers in the constructor to avoid races when ops run
  // concurrently.
  if (FLAGS_caffe2_force_shared_col_buffer || shared_buffer_) {
    createSharedBuffer<CPUContext>(ws_);
  }

  if (FLAGS_caffe2_dnnlowp_shared_int32_buffer) {
    this->CreateSharedInt32Buffer_();
  }

  quantize_groupwise_ =
      this->template GetSingleArgument<bool>("quantize_groupwise", false);
}
template <typename T, bool ReluFused>
ConvDNNLowPOp<T, ReluFused>::~ConvDNNLowPOp() {}
template <typename T, bool ReluFused>
dnnlowp::TensorQuantizationParams&
ConvDNNLowPOp<T, ReluFused>::FilterQuantizationParams(int group_id) {
  return filter_qparams_[quantize_groupwise_ ? group_id : 0];
}
template <typename T, bool ReluFused>
dnnlowp::RequantizationParams&
ConvDNNLowPOp<T, ReluFused>::RequantizationParams(int group_id) {
  return requantization_params_[quantize_groupwise_ ? group_id : 0];
}
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3FastPath_() {
  const Tensor& X = InputTensorCPU_(INPUT);
  return this->order_ == StorageOrder::NHWC && is_same<T, uint8_t>::value &&
      !Acc16() && group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
      this->kernel_.size() == 2 && kernel_h() == 3 && kernel_w() == 3 &&
      stride_h() == stride_w() && (stride_h() == 1 || stride_h() == 2) &&
      dilation_h() == 1 && dilation_w() == 1 && pad_t() == 1 && pad_b() == 1 &&
      pad_l() == 1 && pad_r() == 1 && GetCpuId().avx2();
}
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::TakeDepthWise3x3x3FastPath_() {
  const Tensor& X = InputTensorCPU_(INPUT);
  bool ret = this->order_ == StorageOrder::NHWC && is_same<T, uint8_t>::value &&
      !Acc16() && group_ == X.dim32(X.dim() - 1) && group_ % 8 == 0 &&
      this->kernel_.size() == 3 && this->kernel_[0] == 3 &&
      this->kernel_[1] == 3 && this->kernel_[2] == 3 &&
      this->stride_[0] == this->stride_[1] &&
      this->stride_[0] == this->stride_[2] &&
      (this->stride_[0] == 1 || this->stride_[0] == 2) &&
      this->dilation_[0] == 1 && this->dilation_[1] == 1 &&
      this->dilation_[2] == 1 &&
      accumulate(
          this->pads_.begin(), this->pads_.end(), 1, multiplies<int>()) == 1 &&
      GetCpuId().avx2();
  return ret;
}
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::TakeGConvFastPath_() {
  const Tensor& X = InputTensorCPU_(INPUT);
  if (this->order_ != StorageOrder::NHWC || !is_same<T, uint8_t>::value ||
      !X.template IsType<T>() || this->kernel_.size() != 2) {
    return false;
  }

  auto& filter = InputTensorCPU_(FILTER);
  const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
  const int M = filter.dim32(0);
  fbgemm::conv_param_t<> conv_p(
      N,
      C,
      M,
      {X.dim32(1), X.dim32(2)},
      group_,
      {this->kernel_[0], this->kernel_[1]},
      {this->stride_[0], this->stride_[1]},
      {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});

  return fbgemm::fbgemmOptimizedGConv(conv_p);
}
template <typename T, bool ReluFused>
int ConvDNNLowPOp<T, ReluFused>::KernelDim_() {
  int kernel_dim;
  const Tensor& X = InputTensorCPU_(INPUT);
  const auto& filter = InputTensorCPU_(FILTER);

  int C;
  int filter_offset;
  if (ConvPoolOpBase<CPUContext>::order_ == StorageOrder::NCHW) {
    C = X.dim32(1);
    filter_offset = 2;
  } else {
    C = X.dim32(X.dim() - 1);
    filter_offset = 1;
  }

  int kernel_dims_size = 1;
  for (int i = 0; i < this->kernel_.size(); ++i) {
    CAFFE_ENFORCE_EQ(filter.dim32(i + filter_offset), kernel_[i]);
    kernel_dims_size *= kernel_[i];
  }
  kernel_dim = C / group_ * kernel_dims_size;
  return kernel_dim;
}
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::IsConvGEMM_() const {
  return accumulate(
             this->kernel_.begin(),
             this->kernel_.end(),
             1,
             multiplies<int>()) == 1 &&
      accumulate(
          this->stride_.begin(), this->stride_.end(), 1, multiplies<int>()) ==
      1 &&
      accumulate(
          this->dilation_.begin(),
          this->dilation_.end(),
          1,
          multiplies<int>()) == 1 &&
      accumulate(this->pads_.begin(), this->pads_.end(), 0) == 0;
}
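// A convolution whose kernel, stride, and dilation are all 1 and that has no
// padding is a plain matrix multiplication over the channel dimension, so the
// im2col buffer would simply replicate the input. NoIm2ColNHWC_ below relies
// on this to skip im2col entirely, or to fuse it into the fbgemm packing.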
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::NoIm2ColNHWC_() {
  if (TakeDepthWise3x3FastPath_() || TakeDepthWise3x3x3FastPath_() ||
      TakeGConvFastPath_()) {
    return true;
  }

  if (Wq_packed_ &&
      accumulate(
          this->dilation_.begin(),
          this->dilation_.end(),
          1,
          multiplies<int>()) == 1) {
    // im2col is fused with the packing of the activation matrix
    return true;
  }

  return IsConvGEMM_();
}
template <typename T, bool ReluFused>
void ConvDNNLowPOp<T, ReluFused>::PreComputeRowColumnOffsets_() {
  const auto& filter = InputTensorCPU_(FILTER);
  int kernel_dim = KernelDim_();
  int M = filter.dim32(0);

  // Pre-compute row_offset / column_offset
  vector<int>& offsets =
      StorageOrder::NCHW == ConvPoolOpBase<CPUContext>::order_
      ? row_offsets_
      : *column_offsets_;

  if (offsets.empty()) {
    if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
      const auto& packed_filter =
          this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
      column_offsets_ = packed_filter.column_offsets;
    } else {
      ComputeColumnOffsets<T_signed>(
          kernel_dim, M, W_quantized_.data(), filter_qparams_, offsets);
    }
  }
}
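// Note on zero-point compensation (used by the epilogues below): with
// X_q = X/s_x + x_zp and W_q = W/s_w + w_zp, the integer accumulation obeys
//   sum_k (X_q[k] - x_zp) * (W_q[j][k] - w_zp)
//     = sum_k X_q[k]*W_q[j][k] - x_zp * sum_k W_q[j][k]
//       - w_zp * sum_k X_q[k] + kernel_dim * x_zp * w_zp.
// The per-output-channel weight sums are input-independent, so they are
// precomputed here (or taken from the pre-packed weight blob); the
// activation-dependent sums are computed per output position at run time.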
template <typename T, bool ReluFused>
void ConvDNNLowPOp<T, ReluFused>::QuantizeBias_() {
  using namespace dnnlowp;

  const auto& filter = InputTensorCPU_(FILTER);
  int M = filter.dim32(0);

  // Quantize bias
  bool has_packed_bias =
      this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER) &&
      this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER).bias.get();
  bool has_bias = InputSize() == 3 || has_packed_bias;
  if (has_bias &&
      (!b_quantized_data_ ||
       in_qparams_[INPUT].scale != in_qparams_scale_old_)) {
    if (has_packed_bias) {
      const auto& packed_filter =
          this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
      b_quantized_ = packed_filter.bias;
      b_quantized_data_ = b_quantized_->data();
    } else {
      const auto& bias = InputTensorCPU_(BIAS);
      if (this->template InputIsType<int8::Int8TensorCPU>(BIAS)) {
        TensorQuantizationParams bias_qparams;
        bias_qparams.scale =
            this->template Input<int8::Int8TensorCPU>(BIAS).scale;
        bias_qparams.zero_point =
            this->template Input<int8::Int8TensorCPU>(BIAS).zero_point;
        // A pre-quantized bias must already use scale close to
        // input_scale * filter_scale and zero point 0.
        CAFFE_ENFORCE_LE(
            std::abs(
                bias_qparams.scale -
                in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale),
            1e-4);
        CAFFE_ENFORCE_EQ(bias_qparams.zero_point, 0);
        b_quantized_data_ = bias.template data<int32_t>();
      } else {
        const float* b_data = bias.template data<float>();
        b_quantized_->resize(bias.numel());
        for (int g = 0; g < filter_qparams_.size(); ++g) {
          int i_begin = g * (M / filter_qparams_.size());
          int i_end = i_begin + (M / filter_qparams_.size());
          for (int i = i_begin; i < i_end; ++i) {
            (*b_quantized_)[i] = fbgemm::Quantize<int32_t>(
                b_data[i],
                0,
                in_qparams_[INPUT].scale * FilterQuantizationParams(g).scale,
                32);
          }
        }
        b_quantized_data_ = b_quantized_->data();
      }
      in_qparams_scale_old_ = in_qparams_[INPUT].scale;
    }

    CAFFE_ENFORCE(b_quantized_data_);
  }
}
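// The bias is quantized to int32 with scale = input_scale * filter_scale and
// zero point 0, which is exactly the scale of the int32 accumulators, so it
// can be added to them before requantization without any rescaling. Because
// the quantized bias depends on the input scale, it is re-quantized whenever
// in_qparams_[INPUT].scale changes (dynamic activation quantization).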
template <typename T, bool ReluFused>
void ConvDNNLowPOp<T, ReluFused>::QuantizeWeight_() {
  using namespace dnnlowp;

  // Quantize W if not done already
  int kernel_dim = KernelDim_();
  const auto& filter = InputTensorCPU_(FILTER);
  int M = filter.dim32(0);

  bool packW = ConvPoolOpBase<CPUContext>::order_ == StorageOrder::NHWC &&
      !Acc16() && is_same<T, uint8_t>::value && GetCpuId().avx2() &&
      !FLAGS_caffe2_dnnlowp_force_slow_path;

  bool depthwise_3x3_fast_path = false, depthwise_3x3x3_fast_path = false,
       gconv_fast_path = false;
  if (TakeDepthWise3x3FastPath_()) {
    depthwise_3x3_fast_path = true;
    packW = false;
  } else if (TakeDepthWise3x3x3FastPath_()) {
    depthwise_3x3x3_fast_path = true;
    packW = false;
  } else if (TakeGConvFastPath_()) {
    gconv_fast_path = true;
    packW = false;
  }

  if ((depthwise_3x3_fast_path && !Wq_depthwise_3x3_packed_) ||
      (depthwise_3x3x3_fast_path && !Wq_depthwise_3x3x3_packed_) ||
      (gconv_fast_path && !Wq_gconv_packed_) || (packW && !Wq_packed_) ||
      (!packW && W_quantized_.empty())) {
    if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
      CAFFE_ENFORCE_EQ(
          ConvPoolOpBase<CPUContext>::order_,
          StorageOrder::NHWC,
          "Pre-packed weight only works with NHWC layout");

      const auto& packed_filter =
          this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
      filter_qparams_ = packed_filter.qparams;
    } else {
      filter_qparams_.resize(quantize_groupwise_ ? group_ : 1);
      QuantizeWeight<T>(
          InputBlob(FILTER),
          kernel_dim,
          M,
          filter_qparams_,
          W_quantized_,
          qfactory_.get());

      if (this->template InputIsType<int8::Int8TensorCPU>(FILTER) &&
          quantize_groupwise_) {
        static int log_occurences = 0;
        if (log_occurences < 32) {
          ++log_occurences;
          LOG(WARNING) << "Cannot do group-wise quantization for "
                          "pre-quantized weight "
                       << this->debug_def().input(FILTER);
        }
      }
    }

    filter_zero_points_.resize(filter_qparams_.size());
    requantization_params_.resize(filter_qparams_.size());
    requantization_multipliers_.resize(filter_qparams_.size());
    for (int i = 0; i < filter_qparams_.size(); ++i) {
      filter_zero_points_[i] = filter_qparams_[i].zero_point;
    }

    if (depthwise_3x3_fast_path) {
      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
        const auto& packed_filter =
            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
        Wq_depthwise_3x3_packed_ = packed_filter.W_depthwise_3x3;
      } else {
        Wq_depthwise_3x3_packed_.reset(new fbgemm::Packed3x3ConvMatrix(
            group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
      }
    } else if (depthwise_3x3x3_fast_path) {
      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
        const auto& packed_filter =
            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
        Wq_depthwise_3x3x3_packed_ = packed_filter.W_depthwise_3x3x3;
      } else {
        Wq_depthwise_3x3x3_packed_.reset(new fbgemm::Packed3x3x3ConvMatrix(
            group_, reinterpret_cast<const int8_t*>(W_quantized_.data())));
      }
    } else if (gconv_fast_path) {
      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
        const auto& packed_filter =
            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
        Wq_gconv_packed_ = packed_filter.W_gconv;
      } else {
        const Tensor& X = InputTensorCPU_(INPUT);
        const int N = X.dim32(0), C = X.dim32(X.dim() - 1);

        fbgemm::conv_param_t<> conv_p(
            N,
            C,
            M,
            {X.dim32(1), X.dim32(2)},
            group_,
            {this->kernel_[0], this->kernel_[1]},
            {this->stride_[0], this->stride_[1]},
            {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});

        Wq_gconv_packed_.reset(new fbgemm::PackWeightMatrixForGConv<int8_t>(
            fbgemm::matrix_op_t::Transpose,
            conv_p,
            reinterpret_cast<const int8_t*>(W_quantized_.data())));
      }
    } else if (packW) {
      if (this->template InputIsType<Int8ConvDNNLowPPackedWeightBlob>(FILTER)) {
        const auto& packed_filter =
            this->template Input<Int8ConvDNNLowPPackedWeightBlob>(FILTER);
        Wq_packed_ = packed_filter.W;
      } else {
        // fast path using fbgemm
        Wq_packed_.reset(new fbgemm::PackBMatrix<int8_t>(
            fbgemm::matrix_op_t::Transpose,
            group_ * kernel_dim,
            M / group_,
            reinterpret_cast<const int8_t*>(W_quantized_.data()),
            kernel_dim, // leading dimension
            nullptr, // packed buffer allocated internally
            group_));
      }
    } else {
      string reason;
      if (ConvPoolOpBase<CPUContext>::order_ != StorageOrder::NHWC) {
        reason = "fbgemm only supports NHWC layout";
      } else if (!is_same<T, uint8_t>::value) {
        reason = "fbgemm only supports 8-bit integers";
      } else if (!GetCpuId().avx2()) {
        reason = "fbgemm only supports AVX2+";
      } else if (Acc16()) {
        reason = "";
      } else if (FLAGS_caffe2_dnnlowp_force_slow_path) {
        reason = "slow path enforced";
      } else {
        assert(false);
      }
      if (!reason.empty()) {
        static int log_occurences = 0;
        if (log_occurences < 32) {
          ++log_occurences;
          LOG(WARNING) << "Conv with weight " << this->debug_def().input(FILTER)
                       << " falls back to slow path because " << reason;
        }
      }
    }
  }
}
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() {
  using namespace dnnlowp;

  if (!this->arguments_parsed_) {
    bool dequantize_output;
    ParseDNNLowPOperatorArguments(
        this, &dequantize_output, &measure_quantization_error_, &followed_by_);
    CAFFE_ENFORCE_EQ(
        dequantize_output,
        false,
        "Conv DNNLOWP operators don't support dequantize_output");

    if (ReluFused) {
      // Relu is fused, not followed-by, but setting followed_by_ makes sure
      // quantization error is measured against the post-Relu output.
      followed_by_ = "Relu";
      AdjustOutputTensorQuantizationParamsWithFollowedBy(this, followed_by_);
    }
    this->arguments_parsed_ = true;
  }

  // Choose quantization params for the input
  in_qparams_[INPUT] =
      GetInputTensorQuantizationParamsOf(this, INPUT, qfactory_.get());

  QuantizeWeight_();
  PreComputeRowColumnOffsets_();
  if (Wq_packed_ && !FLAGS_caffe2_dnnlowp_dump_tensors) {
    // From here W_quantized_ is not used anymore when we have Wq_packed_
    vector<T_signed>().swap(W_quantized_);
  }

  QuantizeBias_();

  bool fp32_executed = false;
  if (HasStaticQuantization(this)) {
    out_qparams_ = GetStaticQuantizationParamsOf(this, 0);
  } else {
    // If the output quantization parameters are not given, run the reference
    // fp32 Conv to choose quantization for Y.
    Fp32Op_()->DequantizeInput();
    Fp32Op_()->Get()->RunOnDevice();
    out_qparams_ = Fp32Op_()->GetOutputQuantizationParams(qfactory_.get());
    fp32_executed = true;
  }

  for (int g = 0; g < filter_qparams_.size(); ++g) {
    float real_multiplier = in_qparams_[INPUT].scale *
        FilterQuantizationParams(g).scale / out_qparams_.scale;
    requantization_params_[g] = qfactory_->ChooseRequantizationMultiplier(
        real_multiplier, out_qparams_);
    requantization_multipliers_[g] = requantization_params_[g].real_multiplier;
  }

  if (measure_quantization_error_ && Fp32Op_() && !fp32_executed) {
    // To measure quantization error, run the reference fp32 implementation.
    Fp32Op_()->DequantizeInput();
    Fp32Op_()->Get()->RunOnDevice();
  }

  return true;
}
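// Requantization maps the int32 accumulator back to the output type T:
//   Y_q = clamp(round(real_multiplier * acc) + out_zero_point)
// with real_multiplier = in_scale * filter_scale / out_scale, chosen per
// quantization group above. ChooseRequantizationMultiplier converts this real
// number into the representation used by the requantization kernels.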
template <typename T, bool ReluFused>
void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNCHW_(
    const T* col_buffer_data,
    int32_t* Y_int32,
    T* Y_data,
    size_t i_offset,
    int group_id) {
  auto& filter = InputTensorCPU_(FILTER);
  const int M = filter.dim32(0);
  int kernel_dim = KernelDim_();

  Tensor* Y = OutputTensorCPU_(0);
  const int Y_HxW = this->GetDimsSize(*Y);

  // Adjust with bias and zero points, then requantize.
  int tid = dnnlowp_get_thread_num();
  int32_t* column_offsets = column_offsets_->data() + tid * Y_HxW;

  const dnnlowp::TensorQuantizationParams& filter_qparams =
      FilterQuantizationParams(group_id);
  for (int j = 0; j < Y_HxW; ++j) {
    int sum = 0;
    for (int k = 0; k < kernel_dim; ++k) {
      sum += col_buffer_data[k * Y_HxW + j];
    }
    column_offsets[j] = sum * filter_qparams.zero_point;
  }

  for (int i = 0; i < M / group_; ++i) {
    int32_t row_offset = row_offsets_[i_offset + i];
    row_offset *= -in_qparams_[INPUT].zero_point;
    if (b_quantized_data_) {
      row_offset += b_quantized_data_[i_offset + i];
    }
    for (int j = 0; j < Y_HxW; ++j) {
      int32_t raw = Y_int32[i * Y_HxW + j] + row_offset - column_offsets[j];
      if (ReluFused) {
        raw = std::max(0, raw);
      }
      Y_data[i * Y_HxW + j] =
          fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
    }
  }
}
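// After the offset and bias adjustments, "raw" represents the real-valued
// output scaled by 1 / (in_scale * filter_scale) with no zero point, so
// clamping it at 0 implements the fused ReLU before requantization maps it to
// the output scale and zero point.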
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNCHW() {
  VLOG(2) << "Running DNNLOWP Conv";

  using namespace dnnlowp;

  // Get quantization parameters
  if (!GetQuantizationParameters_()) {
    return false;
  }

  const Tensor& X = InputTensorCPU_(INPUT);
  auto& filter = InputTensorCPU_(FILTER);
  const int N = X.dim32(0), C = X.dim32(1);
  CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
  const int M = filter.dim32(0);
  CAFFE_ENFORCE_EQ(
      C,
      filter.dim32(1) * group_,
      "Convolution op: input channels does not match: # of input channels ",
      C,
      " is not equal to kernel channels * group:",
      filter.dim32(1),
      "*",
      group_);
  CAFFE_ENFORCE_EQ(
      M % group_,
      0,
      "The number of output channels is not divisible by group.");

  auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());

  const vector<int> input_dims = GetDims(X);
  const vector<int> output_dims = GetDims(*Y);
  const int X_HxW = this->GetDimsSize(X);
  const int Y_HxW = this->GetDimsSize(*Y);

  // The dimension of each kernel
  const int kernel_dim = KernelDim_();

  vector<int> img_shape;
  img_shape.assign(X.sizes().begin() + 1, X.sizes().end());

  vector<int> buffer_shape;
  buffer_shape.push_back(kernel_dim);
  buffer_shape.insert(
      buffer_shape.end(), output_dims.begin(), output_dims.end());
  buffer_shape.insert(buffer_shape.begin(), dnnlowp_get_max_threads());

  if (this->kernel_.size() != 2) {
    SetDeviceTensor(img_shape, &img_shape_device_);
    SetDeviceTensor(buffer_shape, &col_buffer_shape_device_);
  }

  const int col_buffer_size = kernel_dim * Y_HxW;

  // The offset corresponding to a single input image / group.
  const int input_offset = C / group_ * X_HxW;

  // The col buffer is stored in CHW order: kernel_dim, then height and width.
  const T* Xdata = X.template data<T>();

  // Must not call mutable_data inside the OpenMP region.
  T* Y_data_T = Y->template mutable_data<T>();
  column_offsets_->resize(Y_HxW * dnnlowp_get_max_threads());

  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
    col_buffer->Resize(buffer_shape);
    vector<int> buffer_shape_per_thread(
        buffer_shape.begin() + 1, buffer_shape.end());
    T* col_buffer_data = col_buffer->template mutable_data<T>();

    Y_int32->resize(M * Y_HxW * dnnlowp_get_max_threads());

    // Im2col, followed by gemm.
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int image_id = 0; image_id < N; ++image_id) {
      int tid = dnnlowp_get_thread_num();
      for (int group_id = 0; group_id < group_; ++group_id) {
        if (this->kernel_.size() == 2) {
          math::Im2ColNCHW<T>(
              C / group_,
              input_dims[0],
              input_dims[1],
              kernel_h(),
              kernel_w(),
              dilation_h(),
              dilation_w(),
              pad_t(),
              pad_l(),
              pad_b(),
              pad_r(),
              stride_h(),
              stride_w(),
              Xdata + (group_ * image_id + group_id) * input_offset,
              col_buffer_data + tid * col_buffer_size,
              &context_,
              in_qparams_[INPUT].zero_point);
        } else {
          math::Im2ColNdNCHW<T>(
              this->kernel_.size(),
              C * X_HxW,
              col_buffer_size,
              img_shape.data(),
              buffer_shape_per_thread.data(),
              this->kernel_.data(),
              this->stride_.data(),
              this->dilation_.data(),
              this->pads_.data(),
              Xdata + (group_ * image_id + group_id) * input_offset,
              col_buffer_data + tid * col_buffer_size,
              &context_,
              in_qparams_[INPUT].zero_point);
        }

        // Quantized col buffer for this thread
        T* col_buffer_private = col_buffer_data + tid * col_buffer_size;

        // Main GEMM (reference int32 accumulation)
        int32_t* Y_int32_temp =
            Y_int32->data() + ((M / group_) * group_id + M * tid) * Y_HxW;
        T_signed* W_quantized_group =
            W_quantized_.data() + (M / group_) * group_id * kernel_dim;

        for (int i = 0; i < M / group_; ++i) {
          for (int j = 0; j < Y_HxW; ++j) {
            int32_t sum = 0;
            for (int k = 0; k < kernel_dim; ++k) {
              int w = W_quantized_group[i * kernel_dim + k];
              int x = col_buffer_private[k * Y_HxW + j];
              sum += w * x;
            }
            Y_int32_temp[i * Y_HxW + j] = sum;
          } // j
        } // i

        RunOnDeviceEpilogueNCHW_(
            col_buffer_private,
            Y_int32_temp,
            Y_data_T + (M * image_id + M / group_ * group_id) * Y_HxW,
            M / group_ * group_id,
            group_id);
      } // for each group
    } // for each image_id

    PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
    MeasureQuantizationError_();
  }; // f

  this->RunWithSharedBuffer_(&col_buffer_, &Y_int32_, f);

  return true;
}
template <typename T, bool ReluFused>
void ConvDNNLowPOp<T, ReluFused>::RunOnDeviceEpilogueNHWC_(
    const T* col_buffer_data,
    int32_t* Y_int32) {
  const Tensor& X = InputTensorCPU_(INPUT);
  auto& filter = InputTensorCPU_(FILTER);
  Tensor* Y = OutputTensorCPU_(0);
  const int N = X.dim32(0);
  const int M = filter.dim32(0);
  int kernel_dim = KernelDim_();
  const int Y_HxW = this->GetDimsSize(*Y);

  // Adjust with bias and zero points, then requantize.
  int32_t A_zero_point = in_qparams_[INPUT].zero_point;

  if (!dnnlowp::HasStaticQuantization(this)) {
    if (quantize_groupwise_) {
      static int log_occurences = 0;
      if (log_occurences < 32) {
        ++log_occurences;
        LOG(WARNING) << "Cannot do group-wise quantization without "
                        "static quantization of activations for "
                     << this->debug_def().output(0);
      }
    }

    int32_t Y_min = numeric_limits<int32_t>::max();
    int32_t Y_max = numeric_limits<int32_t>::min();

#ifdef _OPENMP
#pragma omp parallel for reduction(min : Y_min), reduction(max : Y_max)
#endif
    for (int i = 0; i < N * Y_HxW; ++i) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        int32_t row_offset = 0;
        for (int k = 0; k < kernel_dim; ++k) {
          row_offset +=
              col_buffer_data[(i * group_ + group_id) * kernel_dim + k];
        }
        row_offset *= FilterQuantizationParams(0).zero_point;

        for (int j = group_id * (M / group_); j < (group_id + 1) * (M / group_);
             ++j) {
          int32_t raw = Y_int32[i * M + j] -
              A_zero_point * (*column_offsets_)[j] - row_offset;
          if (b_quantized_data_) {
            raw += b_quantized_data_[j];
          }
          Y_min = std::min(Y_min, raw);
          Y_max = std::max(Y_max, raw);
        }
      } // for each group
    } // for each row i

    if (ReluFused) {
      Y_min = std::max(0, Y_min);
      Y_max = std::max(0, Y_max);
    }

    float Y_scale =
        in_qparams_[INPUT].scale * FilterQuantizationParams(0).scale;
    out_qparams_ =
        qfactory_->ChooseQuantizationParams(Y_scale * Y_min, Y_scale * Y_max);

    float real_multiplier = Y_scale / out_qparams_.scale;
    requantization_params_[0] = qfactory_->ChooseRequantizationMultiplier(
        real_multiplier, out_qparams_);
    requantization_multipliers_[0] = requantization_params_[0].real_multiplier;
  }

  int32_t C_zero_point = out_qparams_.zero_point;

  T* Ydata = Y->template mutable_data<T>();

  using namespace fbgemm;
  if (is_same<T, uint8_t>::value && GetCpuId().avx2()) {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int i = 0; i < N * Y_HxW; ++i) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        int32_t row_offset;
        row_offsets_u8acc32_ref(
            1,
            kernel_dim,
            group_ * kernel_dim,
            reinterpret_cast<const uint8_t*>(
                col_buffer_data + (i * group_ + group_id) * kernel_dim),
            &row_offset);

        int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
        float C_multiplier = RequantizationParams(group_id).real_multiplier;

        requantize_u8acc32_ref(
            1,
            M / group_,
            M,
            Y_int32 + i * M + group_id * (M / group_),
            reinterpret_cast<uint8_t*>(Ydata + i * M + group_id * (M / group_)),
            &C_multiplier,
            C_zero_point,
            A_zero_point,
            &B_zero_point,
            &row_offset,
            column_offsets_->data() + group_id * (M / group_),
            b_quantized_data_ ? b_quantized_data_ + group_id * (M / group_)
                              : nullptr,
            M / group_,
            ReluFused);
      } // for each group
    } // for each row i
  } else {
#ifdef _OPENMP
#pragma omp parallel for
#endif
    for (int i = 0; i < N * Y_HxW; ++i) {
      for (int group_id = 0; group_id < group_; ++group_id) {
        int32_t B_zero_point = FilterQuantizationParams(group_id).zero_point;
        int32_t row_offset = 0;
        for (int k = 0; k < kernel_dim; ++k) {
          row_offset +=
              col_buffer_data[(i * group_ + group_id) * kernel_dim + k];
        }
        row_offset *= B_zero_point;

        for (int j = group_id * (M / group_); j < (group_id + 1) * (M / group_);
             ++j) {
          int32_t raw = Y_int32[i * M + j] -
              A_zero_point * (*column_offsets_)[j] - row_offset;
          if (b_quantized_data_) {
            raw += b_quantized_data_[j];
          }
          Ydata[i * M + j] =
              fbgemm::Requantize<T>(raw, RequantizationParams(group_id));
          if (ReluFused) {
            Ydata[i * M + j] =
                std::max<int32_t>(C_zero_point, Ydata[i * M + j]);
          }
        }
      } // for each group
    } // for each row i
  }

  dnnlowp::PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
}
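// When output quantization parameters are not given statically, the first pass
// above scans the bias-adjusted int32 accumulators for their min/max, scales
// them by in_scale * filter_scale to recover the real-valued output range, and
// chooses out_qparams_ from that range (dynamic quantization). The second pass
// then requantizes with the multiplier derived from the chosen parameters.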
template <typename T, bool ReluFused>
void ConvDNNLowPOp<T, ReluFused>::PartitionGroupedNHWCConv_(
    int* group_begin,
    int* group_end,
    int* i_begin,
    int* i_end,
    int num_groups,
    int m,
    int nthreads,
    int thread_id) {
  // Partition the (group, row) 2D work space across threads. Align the row
  // partition to a multiple of 32 so the int8 GEMM kernels run on full tiles.
  Get1DPartitionOf2D(
      num_groups,
      m,
      nthreads,
      thread_id,
      group_begin,
      group_end,
      i_begin,
      i_end,
      32);
}
template <typename T, bool ReluFused>
const T* ConvDNNLowPOp<T, ReluFused>::Im2ColNHWC_(Tensor* col_buffer) {
  const Tensor& X = InputTensorCPU_(INPUT);
  Tensor* Y = OutputTensorCPU_(0);
  int ndim = X.dim();
  const int N = X.dim32(0), C = X.dim32(ndim - 1);

  const int kernel_dim = KernelDim_();
  // Offsets corresponding to a single input image and a single output image.
  const int X_HxW = this->GetDimsSize(X);
  const int input_offset = X_HxW * C;
  const int Y_HxW = this->GetDimsSize(*Y);

  const T* Xdata = X.template data<T>();

  vector<int> buffer_shape(ndim);
  for (auto i = 0; i < ndim - 1; ++i) {
    buffer_shape[i] = Y->dim32(i);
  }
  buffer_shape[ndim - 1] = kernel_dim * group_;

  col_buffer->Resize(buffer_shape);

  T* col_buffer_data = col_buffer->template mutable_data<T>();

#ifdef _OPENMP
#pragma omp parallel for if (N > 1)
#endif
  for (int image_id = 0; image_id < N; ++image_id) {
    if (this->kernel_.size() <= 2) {
      math::Im2ColNHWC<T>(
          C,
          X.dim32(1),
          this->kernel_.size() == 2 ? X.dim32(2) : 1,
          kernel_h(),
          this->kernel_.size() == 2 ? kernel_w() : 1,
          dilation_h(),
          this->kernel_.size() == 2 ? dilation_w() : 1,
          pad_t(),
          this->kernel_.size() == 2 ? pad_l() : 0,
          this->kernel_.size() == 2 ? pad_b() : pad_l(),
          this->kernel_.size() == 2 ? pad_r() : 0,
          stride_h(),
          this->kernel_.size() == 2 ? stride_w() : 1,
          Xdata + image_id * input_offset,
          col_buffer_data + image_id * group_ * kernel_dim * Y_HxW,
          &context_,
          group_,
          in_qparams_[INPUT].zero_point);
    } else {
      math::Im2Col3DNHWC<T>(
          C,
          X.dim32(1),
          X.dim32(2),
          X.dim32(3),
          this->kernel_[0],
          this->kernel_[1],
          this->kernel_[2],
          this->dilation_[0],
          this->dilation_[1],
          this->dilation_[2],
          this->pads_[0],
          this->pads_[1],
          this->pads_[2],
          this->pads_[3],
          this->pads_[4],
          this->pads_[5],
          this->stride_[0],
          this->stride_[1],
          this->stride_[2],
          Xdata + image_id * input_offset,
          col_buffer_data + image_id * group_ * kernel_dim * Y_HxW,
          &context_,
          group_,
          in_qparams_[INPUT].zero_point);
    }
  }

  return col_buffer->template data<T>();
}
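// For NHWC the col buffer keeps the layout [N, Y_H, Y_W, group * kernel_dim]:
// every output position owns a contiguous patch covering all groups, which is
// what ConvNHWCCore_ consumes as an (N*Y_HxW) x (group*kernel_dim) row-major
// activation matrix. Out-of-image positions are filled with the input zero
// point so they contribute zero after zero-point compensation.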
template <typename T, typename T_signed>
static void conv_nhwc_ref_(
    int group_id,
    int num_groups,
    int i_begin,
    int i_end,
    int M,
    int kernel_dim,
    const T* col_buffer,
    const T_signed* W,
    int32_t* Y) {
  for (int i = i_begin; i < i_end; ++i) {
    for (int j = group_id * (M / num_groups);
         j < (group_id + 1) * (M / num_groups);
         ++j) {
      int32_t sum = 0;
      for (int k = 0; k < kernel_dim; ++k) {
        int w = W[j * kernel_dim + k];
        int x = col_buffer[(i * num_groups + group_id) * kernel_dim + k];
        sum += w * x;
      }
      Y[i * M + j] = sum;
    }
  }
}
template <typename T, bool ReluFused>
template <typename PackAMatrix, fbgemm::QuantizationGranularity Q_GRAN>
void ConvDNNLowPOp<T, ReluFused>::DispatchFBGEMM_(
    PackAMatrix& packA,
    vector<int32_t>* Y_int32,
    uint8_t* Y_uint8_data) {
  // This function is called within an OpenMP region.
  auto& filter = InputTensorCPU_(FILTER);
  const int M = filter.dim32(0);

  int nthreads = dnnlowp_get_num_threads();
  int tid = dnnlowp_get_thread_num();

  using namespace fbgemm;
  DoNothing<> doNothingObj{};
  ReQuantizeOutput<ReluFused, Q_GRAN> outputProcObj(
      doNothingObj,
      requantization_multipliers_.data(),
      out_qparams_.zero_point,
      in_qparams_[INPUT].zero_point,
      filter_zero_points_.data(),
      packA.getRowOffsetBuffer(),
      column_offsets_->data(),
      b_quantized_data_,
      M,
      group_);

  fbgemmPacked(
      packA,
      *Wq_packed_,
      Y_uint8_data,
      Y_int32->data(),
      M,
      outputProcObj,
      tid,
      nthreads);
}
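// ReQuantizeOutput fuses the epilogue into the fbgemm GEMM: it subtracts the
// zero-point compensation terms (A_zero_point * column_offsets and
// B_zero_point * row_offsets), adds the quantized bias, applies the
// requantization multiplier(s), adds the output zero point, and clamps when
// ReluFused, all while the accumulators are still in cache. Q_GRAN selects
// per-tensor vs per-group quantization granularity here.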
template <typename T, bool ReluFused>
void ConvDNNLowPOp<T, ReluFused>::ConvNHWCCore_(
    const T* col_buffer_data,
    vector<int32_t>* Y_int32) {
  const Tensor& X = InputTensorCPU_(INPUT);
  auto& filter = InputTensorCPU_(FILTER);
  Tensor* Y = OutputTensorCPU_(0);
  const int N = X.dim32(0), C = X.dim32(X.dim() - 1);
  const int M = filter.dim32(0);
  const int kernel_dim = KernelDim_();
  const int Y_HxW = this->GetDimsSize(*Y);

  if (FLAGS_caffe2_dnnlowp_dump_tensors) {
    // Dump quantized input activation
    StoreMatrixInMatrixMarketFormat(
        N * Y_HxW * group_,
        kernel_dim,
        col_buffer_data,
        this->debug_def().input(INPUT));

    // Dump quantized weight
    StoreMatrixInMatrixMarketFormat(
        M,
        kernel_dim,
        W_quantized_.data(),
        this->debug_def().input(FILTER));
  }

  using namespace fbgemm;
  if (TakeDepthWise3x3x3FastPath_()) {
    const T* Xdata = X.template data<T>();
    uint8_t* Y_uint8_data =
        OutputTensorCPU_(0)->template mutable_data<uint8_t>();

#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      if (quantize_groupwise_) {
        depthwise_3x3x3_per_channel_quantization_pad_1(
            N,
            X.dim32(1),
            X.dim32(2),
            X.dim32(3),
            C,
            this->stride_[0],
            this->stride_[1],
            this->stride_[2],
            in_qparams_[INPUT].zero_point,
            reinterpret_cast<const uint8_t*>(Xdata),
            filter_zero_points_.data(),
            *Wq_depthwise_3x3x3_packed_,
            requantization_multipliers_.data(),
            out_qparams_.zero_point,
            Y_uint8_data,
            column_offsets_->data(),
            b_quantized_data_,
            ReluFused,
            dnnlowp_get_thread_num(),
            dnnlowp_get_num_threads());
      } else {
        depthwise_3x3x3_pad_1(
            N,
            X.dim32(1),
            X.dim32(2),
            X.dim32(3),
            C,
            this->stride_[0],
            this->stride_[1],
            this->stride_[2],
            in_qparams_[INPUT].zero_point,
            reinterpret_cast<const uint8_t*>(Xdata),
            FilterQuantizationParams(0).zero_point,
            *Wq_depthwise_3x3x3_packed_,
            requantization_params_[0].real_multiplier,
            out_qparams_.zero_point,
            Y_uint8_data,
            column_offsets_->data(),
            b_quantized_data_,
            ReluFused,
            dnnlowp_get_thread_num(),
            dnnlowp_get_num_threads());
      }
    } // omp parallel

    return;
  } else if (TakeDepthWise3x3FastPath_()) {
    const int H = X.dim32(1), W = X.dim32(2);
    const T* Xdata = X.template data<T>();
    uint8_t* Y_uint8_data =
        OutputTensorCPU_(0)->template mutable_data<uint8_t>();

#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      if (quantize_groupwise_) {
        depthwise_3x3_per_channel_quantization_pad_1(
            N,
            H,
            W,
            C,
            stride_h(),
            stride_w(),
            in_qparams_[INPUT].zero_point,
            reinterpret_cast<const uint8_t*>(Xdata),
            filter_zero_points_.data(),
            *Wq_depthwise_3x3_packed_,
            requantization_multipliers_.data(),
            out_qparams_.zero_point,
            Y_uint8_data,
            column_offsets_->data(),
            b_quantized_data_,
            ReluFused,
            dnnlowp_get_thread_num(),
            dnnlowp_get_num_threads());
      } else {
        depthwise_3x3_pad_1(
            N,
            H,
            W,
            C,
            stride_h(),
            stride_w(),
            in_qparams_[INPUT].zero_point,
            reinterpret_cast<const uint8_t*>(Xdata),
            FilterQuantizationParams(0).zero_point,
            *Wq_depthwise_3x3_packed_,
            requantization_params_[0].real_multiplier,
            out_qparams_.zero_point,
            Y_uint8_data,
            column_offsets_->data(),
            b_quantized_data_,
            ReluFused,
            dnnlowp_get_thread_num(),
            dnnlowp_get_num_threads());
      }
    } // omp parallel

    return;
  } else if (TakeGConvFastPath_()) {
    const T* Xdata = X.template data<T>();
    uint8_t* Y_uint8_data =
        OutputTensorCPU_(0)->template mutable_data<uint8_t>();

    conv_param_t<> conv_p(
        N,
        C,
        M,
        {X.dim32(1), X.dim32(2)},
        group_,
        {this->kernel_[0], this->kernel_[1]},
        {this->stride_[0], this->stride_[1]},
        {this->pads_[0], this->pads_[1], this->pads_[2], this->pads_[3]});

    int row_offset_size_per_thread = rowOffsetBufferSizeGConv(conv_p);
    row_offsets_.resize(dnnlowp_get_max_threads() * row_offset_size_per_thread);

#ifdef _OPENMP
#pragma omp parallel
#endif
    {
      int tid = dnnlowp_get_thread_num();
      int nthreads = dnnlowp_get_num_threads();

      DoNothing<> doNothingObj{};
      if (quantize_groupwise_) {
        ReQuantizeOutput<false, QuantizationGranularity::GROUP> reqObj(
            doNothingObj,
            requantization_multipliers_.data(),
            out_qparams_.zero_point,
            in_qparams_[INPUT].zero_point,
            filter_zero_points_.data(),
            row_offsets_.data() + tid * row_offset_size_per_thread,
            column_offsets_->data(),
            b_quantized_data_,
            conv_p.OC,
            conv_p.G);

        fbgemmGroupwiseConv(
            conv_p,
            reinterpret_cast<const uint8_t*>(Xdata),
            in_qparams_[INPUT].zero_point,
            row_offsets_.data() + tid * row_offset_size_per_thread,
            *Wq_gconv_packed_,
            Y_uint8_data,
            Y_int32->data(),
            reqObj,
            tid,
            nthreads);
      } else {
        ReQuantizeOutput<false, QuantizationGranularity::TENSOR> reqObj(
            doNothingObj,
            requantization_multipliers_.data(),
            out_qparams_.zero_point,
            in_qparams_[INPUT].zero_point,
            filter_zero_points_.data(),
            row_offsets_.data() + tid * row_offset_size_per_thread,
            column_offsets_->data(),
            b_quantized_data_,
            conv_p.OC,
            conv_p.G);

        fbgemmGroupwiseConv(
            conv_p,
            reinterpret_cast<const uint8_t*>(Xdata),
            in_qparams_[INPUT].zero_point,
            row_offsets_.data() + tid * row_offset_size_per_thread,
            *Wq_gconv_packed_,
            Y_uint8_data,
            Y_int32->data(),
            reqObj,
            tid,
            nthreads);
      }
    } // omp parallel

    return;
  }
  // Normal path: pack the activation matrix and dispatch to fbgemm.
  int row_offset_size_per_thread = -1;
  int x_pack_buf_size_per_thread = -1;
  bool fuse_im2col =
      Wq_packed_ && X.template data<T>() == col_buffer_data && !IsConvGEMM_();
  if (Wq_packed_) {
    if (fuse_im2col) {
      row_offset_size_per_thread =
          PackAWithIm2Col<uint8_t>::rowOffsetBufferSize();
      x_pack_buf_size_per_thread = PackAWithIm2Col<uint8_t>::packedBufferSize();
    } else {
      row_offset_size_per_thread =
          PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
      x_pack_buf_size_per_thread =
          PackAWithRowOffset<uint8_t>::packedBufferSize();
    }
    row_offsets_.resize(dnnlowp_get_max_threads() * row_offset_size_per_thread);
    X_pack_buf_.resize(dnnlowp_get_max_threads() * x_pack_buf_size_per_thread);
  }

  uint8_t* Y_uint8_data = Y->template mutable_data<uint8_t>();

  if (Wq_packed_)
#ifdef _OPENMP
#pragma omp parallel
#endif
  {
    int tid = dnnlowp_get_thread_num();

    // fast path using fbgemm
    if (fuse_im2col) {
      if (this->kernel_.size() <= 2) {
        conv_param_t<> conv_p(
            N,
            C,
            M,
            {X.dim32(1), this->kernel_.size() == 2 ? X.dim32(2) : 1},
            group_,
            {this->kernel_[0],
             this->kernel_.size() == 2 ? this->kernel_[1] : 1},
            {this->stride_[0],
             this->kernel_.size() == 2 ? this->stride_[1] : 1},
            {this->pads_[0],
             this->kernel_.size() == 2 ? this->pads_[1] : 0,
             this->kernel_.size() == 2 ? this->pads_[2] : this->pads_[1],
             this->kernel_.size() == 2 ? this->pads_[3] : 0});

        PackAWithIm2Col<uint8_t> packA(
            conv_p,
            reinterpret_cast<const uint8_t*>(col_buffer_data),
            // buffer for the packed matrix
            X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
            in_qparams_[INPUT].zero_point,
            row_offsets_.data() + tid * row_offset_size_per_thread);

        if (quantize_groupwise_) {
          DispatchFBGEMM_<
              PackAWithIm2Col<uint8_t>,
              QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
        } else {
          DispatchFBGEMM_<
              PackAWithIm2Col<uint8_t>,
              QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
        }
      } else {
        // 3D convolution
        conv_param_t<3> conv_p(
            N,
            C,
            M,
            {X.dim32(1), X.dim32(2), X.dim32(3)},
            group_,
            {this->kernel_[0], this->kernel_[1], this->kernel_[2]},
            {this->stride_[0], this->stride_[1], this->stride_[2]},
            {this->pads_[0],
             this->pads_[1],
             this->pads_[2],
             this->pads_[3],
             this->pads_[4],
             this->pads_[5]});

        PackAWithIm2Col<uint8_t, int32_t, 3> packA(
            conv_p,
            reinterpret_cast<const uint8_t*>(col_buffer_data),
            // buffer for the packed matrix
            X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
            in_qparams_[INPUT].zero_point,
            row_offsets_.data() + tid * row_offset_size_per_thread);

        if (quantize_groupwise_) {
          DispatchFBGEMM_<
              PackAWithIm2Col<uint8_t, int32_t, 3>,
              QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
        } else {
          DispatchFBGEMM_<
              PackAWithIm2Col<uint8_t, int32_t, 3>,
              QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
        }
      } // 2D vs 3D
    } else {
      // im2col has already been materialized; only row offsets are needed.
      PackAWithRowOffset<uint8_t> packA(
          matrix_op_t::NoTranspose,
          N * Y_HxW,
          group_ * kernel_dim,
          reinterpret_cast<const uint8_t*>(col_buffer_data),
          group_ * kernel_dim,
          // buffer for the packed matrix
          X_pack_buf_.data() + tid * x_pack_buf_size_per_thread,
          group_,
          row_offsets_.data() + tid * row_offset_size_per_thread);

      if (quantize_groupwise_) {
        DispatchFBGEMM_<
            PackAWithRowOffset<uint8_t>,
            QuantizationGranularity::GROUP>(packA, Y_int32, Y_uint8_data);
      } else {
        DispatchFBGEMM_<
            PackAWithRowOffset<uint8_t>,
            QuantizationGranularity::TENSOR>(packA, Y_int32, Y_uint8_data);
      }
    } // no im2col fusion
  } else {
    // Slow reference path when the weight is not packed for fbgemm.
    for (int group_id = 0; group_id < group_; ++group_id) {
      conv_nhwc_ref_(
          group_id,
          group_,
          0,
          N * Y_HxW,
          M,
          kernel_dim,
          col_buffer_data,
          W_quantized_.data(),
          Y_int32->data());
    }
  }
}
template <typename T, bool ReluFused>
bool ConvDNNLowPOp<T, ReluFused>::RunOnDeviceWithOrderNHWC() {
  CAFFE_ENFORCE_LE(
      this->kernel_.size(),
      3,
      "Only 1-3d convolutions are supported for NHWC storage type");

  using namespace dnnlowp;

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
  chrono::time_point<chrono::system_clock> t_very_begin, t_begin, t_end;
  {
    t_begin = chrono::system_clock::now();
    t_very_begin = t_begin;
  }
#endif

  // Get quantization parameters
  if (!GetQuantizationParameters_()) {
    return false;
  }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
  {
    t_end = chrono::system_clock::now();
    double dt = chrono::duration<double>(t_end - t_begin).count();
    LOG(INFO) << "this=" << this << " get_quant_params: " << dt * 1e3 << " ms";
  }
#endif

  const Tensor& X = InputTensorCPU_(INPUT);
  auto& filter = InputTensorCPU_(FILTER);
  const int C = X.dim32(X.dim() - 1);
  const int G = group_;
  CAFFE_ENFORCE_EQ(X.dim(), filter.dim());
  const int M = filter.dim32(0);
  CAFFE_ENFORCE_EQ(
      C,
      filter.dim32(filter.dim() - 1) * G,
      "Convolution op: input channels does not match: # of input channels ",
      C,
      " is not equal to kernel channels * group: ",
      filter.dim32(filter.dim() - 1),
      "*",
      G);
  CAFFE_ENFORCE_EQ(
      M % G, 0, "The number of output channels is not divisible by group.");

  auto sizes = ConvPoolOpBase<CPUContext>::GetOutputSize(X, filter.dim32(0));
  Tensor* Y = OutputTensorCPU_(0, sizes, at::dtype<T>());

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
  { t_begin = chrono::system_clock::now(); }
#endif

  bool no_im2col = NoIm2ColNHWC_();
  auto f = [&](Tensor* col_buffer, vector<int32_t>* Y_int32) {
    if (!TakeDepthWise3x3FastPath_() && !TakeDepthWise3x3x3FastPath_()) {
      Y_int32->resize(Y->numel());
    }

    // Im2col, followed by gemm.
    const T* Xdata = X.template data<T>();
    const T* col_buffer_data = no_im2col ? Xdata : Im2ColNHWC_(col_buffer);

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    {
      t_end = chrono::system_clock::now();
      double dt = chrono::duration<double>(t_end - t_begin).count();
      LOG(INFO) << "this=" << this << " im2col: " << dt * 1e3 << " ms";
      t_begin = chrono::system_clock::now();
    }
#endif

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    {
      t_end = chrono::system_clock::now();
      double dt = chrono::duration<double>(t_end - t_begin).count();
      LOG(INFO) << "this=" << this << " quantize col_buf: " << dt * 1e3
                << " ms";
      t_begin = chrono::system_clock::now();
    }
#endif

    ConvNHWCCore_(col_buffer_data, Y_int32);

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    {
      t_end = chrono::system_clock::now();
      double dt = chrono::duration<double>(t_end - t_begin).count();
      LOG(INFO) << "this=" << this << " GEMM: " << dt * 1e3 << " ms";
      t_begin = chrono::system_clock::now();
    }
#endif

    if (Wq_packed_ || Wq_depthwise_3x3_packed_ || Wq_depthwise_3x3x3_packed_ ||
        Wq_gconv_packed_) {
      // In the fbgemm fast paths the epilogue is fused into the kernel.
      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
    } else {
      RunOnDeviceEpilogueNHWC_(col_buffer_data, Y_int32->data());
    }
  }; // f

  this->RunWithSharedBuffer_(&col_buffer_, &Y_int32_, f);

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
  {
    const int N = X.dim32(0);
    const int kernel_dim = KernelDim_();
    const int Y_HxW = this->GetDimsSize(*Y);

    t_end = chrono::system_clock::now();
    double dt = chrono::duration<double>(t_end - t_begin).count();
    LOG(INFO) << "this=" << this << " prologue: " << dt * 1e3 << " ms";
    t_begin = chrono::system_clock::now();

    t_end = chrono::system_clock::now();
    const int M = filter.dim32(0);
    double ops = 2. * N * Y_HxW * M * kernel_dim;
    dt = chrono::duration<double>(t_end - t_very_begin).count();
    double gops = ops / dt / 1e9;
    LOG(INFO) << "this=" << this << " " << this->debug_def().type()
              << " output=" << this->debug_def().output(0) << " " << N * Y_HxW
              << "x" << M << "x" << kernel_dim << " G=" << group_
              << " C/G=" << C / group_ << " K/G=" << M / group_
              << " R=" << kernel_h() << " S=" << kernel_w() << " : " << dt * 1e3
              << " ms " << gops << " gops";
  }
#endif

  MeasureQuantizationError_();

  return true;
}
OPERATOR_SCHEMA(ConvRelu)
    .NumInputs(2, 3)
    .NumOutputs(1)
    .TensorInferenceFunction(
        ConvPoolOpBase<CPUContext>::TensorInferenceForConv);

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Conv,
    DNNLOWP,
    ConvDNNLowPOp<uint8_t, false>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    ConvRelu,
    DNNLOWP,
    ConvDNNLowPOp<uint8_t, true>);

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8Conv,
    DNNLOWP,
    ConvDNNLowPOp<uint8_t, false>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8ConvRelu,
    DNNLOWP,
    ConvDNNLowPOp<uint8_t, true>);

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Conv,
    DNNLOWP_16,
    ConvDNNLowPOp<uint16_t, false>);

} // namespace caffe2