#include "batch_matmul_dnnlowp_op.h"

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
#include <chrono>
#endif

namespace caffe2 {

using namespace std;
using namespace dnnlowp;
using namespace fbgemm;

template <typename T>
BatchMatMulDNNLowPOp<T>::BatchMatMulDNNLowPOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : BaseType(operator_def, ws),
      trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
      trans_b_(this->template GetSingleArgument<int>("trans_b", 0)),
      broadcast_(this->template GetSingleArgument<int>("broadcast", 0)),
      is_B_constant_(
          this->template GetSingleArgument<bool>("constant_B", false)) {}
template <typename T>
bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {
  this->ParseDNNLowPOperatorArguments_();

  const auto& A = InputTensorCPU_(0);
  const auto& B = InputTensorCPU_(1);
  auto* Y = OutputTensorCPU_(0);

  auto ndims_A = A.ndim();
  auto dims_A = A.sizes().vec();
  auto ndims_B = B.ndim();
  auto dims_B = B.sizes().vec();
  auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
    std::stringstream ss;
    ss << "Inputs with dimensions A = " << dim1 << " and B = " << dim2
       << " is not supported with broadcast=0. Did you forget to set the "
          "broadcast flag?";
    return ss.str();
  };

  bool dimMismatch = ndims_A != ndims_B;
  bool dimsLessThan1D = ndims_A < 2;
  CAFFE_ENFORCE(
      broadcast_ || (!dimMismatch && !dimsLessThan1D),
      noBroadcastErrorMsg(ndims_A, ndims_B));
  auto dimMismatchErrorString = [](size_t dimnum1,
                                   size_t dim1,
                                   size_t dimnum2,
                                   size_t dim2,
                                   bool trans_a,
                                   bool trans_b) {
    std::stringstream ss;
    ss << "Expected dimension " << dimnum1 << " of tensor A with value "
       << dim1 << " to match dimension " << dimnum2
       << " of tensor B with value " << dim2 << ". trans_a = " << trans_a
       << " trans_b = " << trans_b;
    return ss.str();
  };
  int M, N, K;
  int A_stride = 1, B_stride = 1, Y_stride = 1;
  int num_sub_batches, num_outer_batches;
  if (ndims_A == 1 && ndims_B == 1) {
    CAFFE_ENFORCE_EQ(
        dims_A[0], dims_B[0],
        "Vector-vector product requires each of the vectors to "
        "be the same size.");
    Y->Resize(1);
    num_sub_batches = 1;
    num_outer_batches = 1;
    M = N = 1;
    K = dims_A[0];
  } else {
    bool A_broadcasted = false, B_broadcasted = false;
    if (ndims_A == 1) {
      dims_A.insert(dims_A.begin(), 1);
      ndims_A = 2;
      A_broadcasted = true;
    }
    if (ndims_B == 1) {
      dims_B.push_back(1);
      ndims_B = 2;
      B_broadcasted = true;
    }
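    // Batch dimensions (everything before the last two dims) of A and B must
    // agree elementwise; the extra leading dims of the higher-rank input
    // become outer batches.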
    size_t num_inner_dims = std::min(ndims_A, ndims_B);
    for (size_t i = 2; i < num_inner_dims; ++i) {
      auto first_r_itr = dims_A.rbegin();
      auto second_r_itr = dims_B.rbegin();
      CAFFE_ENFORCE_EQ(
          *(first_r_itr + i),
          *(second_r_itr + i),
          dimMismatchErrorString(
              ndims_A - i - 1, *(first_r_itr + i),
              ndims_B - i - 1, *(second_r_itr + i),
              trans_a_, trans_b_));
    }
    size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;
    size_t K_dim;
    if (trans_a_) {
      M = dims_A[ndims_A - 1];
      K = dims_A[ndims_A - 2];
      K_dim = ndims_A - 2;
    } else {
      M = dims_A[ndims_A - 2];
      K = dims_A[ndims_A - 1];
      K_dim = ndims_A - 1;
    }
    if (trans_b_) {
      N = dims_B[ndims_B - 2];
      CAFFE_ENFORCE_EQ(
          K, dims_B[ndims_B - 1],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 1, dims_B[ndims_B - 1], trans_a_, trans_b_));
    } else {
      N = dims_B[ndims_B - 1];
      CAFFE_ENFORCE_EQ(
          K, dims_B[ndims_B - 2],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 2, dims_B[ndims_B - 2], trans_a_, trans_b_));
    }
    std::vector<int64_t> new_dims;
    if (ndims_A >= ndims_B) {
      new_dims.assign(dims_A.begin(), dims_A.end() - 2);
    } else {
      new_dims.assign(dims_B.begin(), dims_B.end() - 2);
    }
    if (!A_broadcasted) {
      new_dims.push_back(M);
    } else {
      new_dims.push_back(1);
    }
    if (!B_broadcasted) {
      new_dims.push_back(N);
    } else {
      new_dims.push_back(1);
    }
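    // A_stride, B_stride, and Y_stride count elements between consecutive
    // outer batches; num_sub_batches covers the matched batch dims and
    // num_outer_batches the extra leading dims of the higher-rank input.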
    num_sub_batches = 1;
    if (ndims_A >= ndims_B) {
      auto first_r_itr = dims_A.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        A_stride *= *(first_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(first_r_itr + i);
        }
      }
    } else {
      auto second_r_itr = dims_B.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        B_stride *= *(second_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(second_r_itr + i);
        }
      }
    }

    num_outer_batches = 1;
    for (size_t i = 0; i < num_outer_dims; ++i) {
      num_outer_batches *= new_dims[i];
    }

    if (A_broadcasted) {
      new_dims.erase(new_dims.end() - 2);
    } else if (B_broadcasted) {
      new_dims.erase(new_dims.end() - 1);
    }

    Y->Resize(new_dims);
  }
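  // If every outer batch holds a single GEMM, fold the outer batches into M
  // so one larger GEMM is issued per batch; zero-sized batches only need the
  // output buffer to be materialized.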
  if (num_sub_batches == 1 && num_outer_batches > 1) {
    if (ndims_A > ndims_B && !trans_a_) {
      M *= num_outer_batches;
      num_outer_batches = 1;
    }
  }

  if (num_sub_batches == 0 || num_outer_batches == 0) {
    if (dequantize_output_) {
      Y->template mutable_data<float>();
    } else {
      Y->template mutable_data<T>();
    }
    return true;
  }
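  // Pick quantization parameters for A and decide whether the fbgemm fast
  // path applies: 8-bit element type, AVX2, and a constant B.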
  in_qparams_[0] =
      GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());

  int num_batches_B = B.numel() / (K * N);
  if (!first_invocation_ && !Bq_packed_.empty() &&
      num_batches_B * N != column_offsets_.size()) {
    LOG(INFO) << "Operator with output " << this->debug_def().output(0)
              << " does not have constant B";
    is_B_constant_ = false;
  }

  bool fast_path =
      std::is_same<T, uint8_t>::value && GetCpuId().avx2() && is_B_constant_;

  if (fast_path) {
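    // Quantize and pack B once (it is constant): each batch of B is converted
    // to signed 8-bit (shifting by signed_min when the input is already
    // quantized as uint8), packed with fbgemm::PackBMatrix, and its per-column
    // sums are cached in column_offsets_ for the zero-point corrections.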
    if (Bq_packed_.empty()) {
      int signed_min = -(1 << (qfactory_->GetWeightPrecision() - 1));
      vector<int8_t> B_quantized_temp(K * N);
      column_offsets_.resize(num_batches_B * N);
      for (int i = 0; i < num_batches_B; ++i) {
        if (this->template InputIsType<int8::Int8TensorCPU>(1)) {
          B_qparams_.push_back(TensorQuantizationParams());
          B_qparams_[i].scale =
              this->template Input<int8::Int8TensorCPU>(1).scale;
          B_qparams_[i].zero_point =
              this->template Input<int8::Int8TensorCPU>(1).zero_point +
              signed_min;

          const T* B_data =
              B.template data<T>() + i * B_quantized_temp.size();
          for (auto j = 0; j < B_quantized_temp.size(); ++j) {
            B_quantized_temp[j] = B_data[j] + signed_min;
          }
        } else {
          B_qparams_.emplace_back(qfactory_->ChooseQuantizationParams(
              B.template data<float>() + i * B_quantized_temp.size(),
              B_quantized_temp.size(),
              true /* weight */));
          // Adjust the zero point from the unsigned to the signed 8-bit range.
          B_qparams_[i].zero_point += signed_min;

          fbgemm::Quantize<int8_t>(
              B.template data<float>() + i * B_quantized_temp.size(),
              B_quantized_temp.data(),
              B_quantized_temp.size(),
              B_qparams_[i]);
        }

        Bq_packed_.emplace_back(new fbgemm::PackBMatrix<int8_t>(
            trans_b_ ? fbgemm::matrix_op_t::Transpose
                     : fbgemm::matrix_op_t::NoTranspose,
            K,
            N,
            B_quantized_temp.data(),
            trans_b_ ? K : N));

        // Pre-compute column sums of the quantized B.
        for (int j = 0; j < N; ++j) {
          int32_t sum = 0;
          if (trans_b_) {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_temp[j * K + k];
            }
          } else {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_temp[k * N + j];
            }
          }
          column_offsets_[i * N + j] = sum - B_qparams_[i].zero_point * K;
        }
      } // for each batch of B
    } // Bq_packed_.empty()
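    // For quantized output, precompute one requantization multiplier per
    // batch of B: in_scale * B_scale / out_scale.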
    if (!dequantize_output_) {
      GetOutputQuantizationParams_();

      for (int i = 0; i < num_batches_B; ++i) {
        float real_multiplier =
            in_qparams_[0].scale * B_qparams_[i].scale / out_qparams_.scale;
        requantization_params_.emplace_back(
            qfactory_->ChooseRequantizationMultiplier(
                real_multiplier, out_qparams_));
      }
    } else {
      if (measure_quantization_error_) {
        // To measure quantization error, run the reference fp32 op as well.
        Fp32Op_()->DequantizeInput();
        Fp32Op_()->Get()->RunOnDevice();
      }
    }
  } else {
    if (first_invocation_) {
      string reason;
      if (!is_same<T, uint8_t>::value) {
        reason = "fbgemm only supports 8-bit integers";
      } else if (!GetCpuId().avx2()) {
        reason = "fbgemm only supports AVX2";
      } else if (!is_B_constant_) {
        reason = "B is not constant";
      }
      LOG(WARNING) << "BatchMatMul with output " << this->debug_def().output(0)
                   << " falls back to slow path because " << reason;
    }

    B_qparams_.resize(1);
    requantization_params_.resize(1);

    B_qparams_[0] =
        GetInputTensorQuantizationParamsOf(this, 1, qfactory_.get());

    GetOutputQuantizationParams_();

    float real_multiplier =
        in_qparams_[0].scale * B_qparams_[0].scale / out_qparams_.scale;
    requantization_params_[0] = qfactory_->ChooseRequantizationMultiplier(
        real_multiplier, out_qparams_);
  }

  first_invocation_ = false;
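  // Execute the matmul. The fast path (Bq_packed_ non-empty) goes through
  // fbgemm's packing + requantization pipeline; otherwise fall through to the
  // reference loop further below.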
  vector<T> A_temp, B_temp;
  if (!Bq_packed_.empty()) {
    // Fast path with fbgemm.
    const T* A_quantized = nullptr;
    if (A.template IsType<T>() || !dequantize_output_) {
      // Quantize A here only if it is already quantized or the output must be
      // quantized anyway; otherwise quantization is fused into packing below.
      A_quantized = QuantizeInputIfNeeded<T>(this, 0, in_qparams_[0], A_temp);
    }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    chrono::time_point<chrono::system_clock> t_begin, t_end;
    t_begin = chrono::system_clock::now();
#endif

    if (!dequantize_output_) {
      auto Y_data = Y->template mutable_data<T>();

      auto row_offset_len_per_thread =
          PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
      row_offsets_.resize(
          row_offset_len_per_thread * dnnlowp_get_max_threads());
      auto A_pack_buf_len_per_thread =
          PackAWithRowOffset<uint8_t>::packedBufferSize();
      A_pack_buf_.resize(A_pack_buf_len_per_thread * dnnlowp_get_max_threads());
      Y_int32_.resize(Y->numel());
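      // Each (outer batch, sub batch) pair packs one M x K slice of A
      // (collecting per-row sums for the zero-point correction) and runs it
      // against the pre-packed B of that batch; ReQuantizeOutput folds the
      // offsets and the multiplier so results are written directly as
      // quantized T.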
#pragma omp parallel for collapse(2)
      for (int p = 0; p < num_outer_batches; ++p) {
        for (int i = 0; i < num_sub_batches; ++i) {
          int tid = dnnlowp_get_thread_num();

          PackAWithRowOffset<uint8_t> packA(
              trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
              M,
              K,
              reinterpret_cast<const uint8_t*>(A_quantized) + p * A_stride +
                  i * M * K,
              trans_a_ ? M : K,
              A_pack_buf_.data() +
                  tid * A_pack_buf_len_per_thread, // buffer for packed matrix
              1, // group
              row_offsets_.data() + tid * row_offset_len_per_thread);

          int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
          DoNothing<> doNothingObj{};
          ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
              doNothingObj,
              &requantization_params_[B_batch_idx].real_multiplier,
              out_qparams_.zero_point,
              in_qparams_[0].zero_point,
              &B_qparams_[B_batch_idx].zero_point,
              packA.getRowOffsetBuffer(),
              column_offsets_.data() + B_batch_idx * N,
              nullptr, // bias
              N); // ncols per quantization group

          fbgemmPacked(
              packA,
              *Bq_packed_[B_batch_idx],
              reinterpret_cast<uint8_t*>(Y_data) + p * Y_stride + i * M * N,
              Y_int32_.data() + p * Y_stride + i * M * N,
              N,
              outputProcObj,
              0, // thread_id
              1); // num_threads
        } // for each sub-batch
      } // for each outer batch

      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
    } else {
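      // Dequantized output: the same fbgemm pipeline is used, but
      // ReQuantizeForFloat scales the int32 accumulators back to float using
      // in_scale * B_scale.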
      float* Y_data = Y->template mutable_data<float>();

      if (!A.template IsType<T>()) {
        // A is in float: fuse quantization of A into the packing.
        int row_offset_len_per_thread =
            PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize();
        row_offsets_.resize(
            row_offset_len_per_thread * dnnlowp_get_max_threads());
        int A_pack_len_per_thread =
            PackAWithQuantRowOffset<uint8_t>::packedBufferSize();
        A_pack_buf_.resize(A_pack_len_per_thread * dnnlowp_get_max_threads());

#pragma omp parallel for collapse(2)
        for (int p = 0; p < num_outer_batches; ++p) {
          for (int i = 0; i < num_sub_batches; ++i) {
            int tid = dnnlowp_get_thread_num();

            PackAWithQuantRowOffset<uint8_t> packA(
                trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
                M,
                K,
                A.template data<float>() + p * A_stride + i * M * K,
                trans_a_ ? M : K,
                A_pack_buf_.data() +
                    tid * A_pack_len_per_thread, // buffer for packed matrix
                in_qparams_[0].scale,
                in_qparams_[0].zero_point,
                1, // group
                row_offsets_.data() + tid * row_offset_len_per_thread);

            int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
            DoNothing<float, float> doNothingObj{};
            ReQuantizeForFloat<false /* FUSE_RELU */> outputProcObj(
                doNothingObj,
                in_qparams_[0].scale,
                &B_qparams_[B_batch_idx].scale,
                in_qparams_[0].zero_point,
                &B_qparams_[B_batch_idx].zero_point,
                packA.getRowOffsetBuffer(),
                column_offsets_.data() + B_batch_idx * N,
                nullptr, // bias
                N); // ncols per quantization group

            fbgemmPacked(
                packA,
                *Bq_packed_[B_batch_idx],
                Y_data + p * Y_stride + i * M * N,
                reinterpret_cast<int32_t*>(Y_data) + p * Y_stride + i * M * N,
                N,
                outputProcObj,
                0, // thread_id
                1); // num_threads
          } // for each sub-batch
        } // for each outer batch
      } else {
        // A is already quantized; pack its quantized values directly.
        auto row_offset_len_per_thread =
            PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
        row_offsets_.resize(
            row_offset_len_per_thread * dnnlowp_get_max_threads());
        auto A_pack_buf_len_per_thread =
            PackAWithRowOffset<uint8_t>::packedBufferSize();
        A_pack_buf_.resize(
            A_pack_buf_len_per_thread * dnnlowp_get_max_threads());

#pragma omp parallel for collapse(2)
        for (int p = 0; p < num_outer_batches; ++p) {
          for (int i = 0; i < num_sub_batches; ++i) {
            int tid = dnnlowp_get_thread_num();

            PackAWithRowOffset<uint8_t> packA(
                trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
                M,
                K,
                reinterpret_cast<const uint8_t*>(A_quantized) + p * A_stride +
                    i * M * K,
                trans_a_ ? M : K,
                A_pack_buf_.data() +
                    tid * A_pack_buf_len_per_thread, // buffer for packed matrix
                1, // group
                row_offsets_.data() + tid * row_offset_len_per_thread);

            int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
            DoNothing<float, float> doNothingObj{};
            ReQuantizeForFloat<false /* FUSE_RELU */> outputProcObj(
                doNothingObj,
                in_qparams_[0].scale,
                &B_qparams_[B_batch_idx].scale,
                in_qparams_[0].zero_point,
                &B_qparams_[B_batch_idx].zero_point,
                packA.getRowOffsetBuffer(),
                column_offsets_.data() + B_batch_idx * N,
                nullptr, // bias
                N); // ncols per quantization group

            fbgemmPacked(
                packA,
                *Bq_packed_[B_batch_idx],
                Y_data + p * Y_stride + i * M * N,
                reinterpret_cast<int32_t*>(Y_data) + p * Y_stride + i * M * N,
                N,
                outputProcObj,
                0, // thread_id
                1); // num_threads
          } // for each sub-batch
        } // for each outer batch
      } // A.template IsType<T>()
    } // dequantize_output_
#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    t_end = chrono::system_clock::now();
    double dt = chrono::duration<double>(t_end - t_begin).count();
    double gops =
        2. * num_outer_batches * num_sub_batches * M * N * K / dt / 1e9;
    LOG(INFO) << "batches " << num_outer_batches * num_sub_batches << " m " << M
              << " n " << N << " k " << K << " " << gops << " gops";
#endif

    MeasureQuantizationError_();
  } else {
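    // Slow path: quantize A and B on the fly and compute the product with a
    // reference triple loop, accumulating in int32.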
    const T* A_quantized =
        QuantizeInputIfNeeded<T>(this, 0, in_qparams_[0], A_temp);
    const T* B_quantized =
        QuantizeInputIfNeeded<T>(this, 1, B_qparams_[0], B_temp);

    T* Y_quantized = GetQuantizedOutputData_();
    Y_int32_.resize(Y->numel());

#pragma omp parallel for collapse(2)
    for (int p = 0; p < num_outer_batches; ++p) {
      for (int i = 0; i < num_sub_batches; ++i) {
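        // Reference quantized GEMM. With A_zp = in_qparams_[0].zero_point and
        // B_zp = B_qparams_[0].zero_point, the loops below compute, in effect,
        //   Y_int32[m][n] = sum_k (A_q[m][k] - A_zp) * (B_q[k][n] - B_zp)
        //                 = sum_k A_q*B_q
        //                   - B_zp * sum_k A_q[m][k]   (row_offset)
        //                   - A_zp * sum_k B_q[k][n]   (column_offsets[n])
        //                   + A_zp * B_zp * K          (const_offset)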
        const T* A_quantized_i = A_quantized + p * A_stride + i * M * K;
        const T* B_quantized_i = B_quantized + p * B_stride + i * K * N;

        int32_t const_offset =
            in_qparams_[0].zero_point * B_qparams_[0].zero_point * K;
        vector<int32_t> column_offsets(N);
        for (int n = 0; n < N; ++n) {
          int32_t sum = 0;
          if (trans_b_) {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_i[k + n * K];
            }
          } else {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_i[k * N + n];
            }
          }
          column_offsets[n] = sum * in_qparams_[0].zero_point;
        }
        for (int m = 0; m < M; ++m) {
          int32_t row_offset = 0;
          if (trans_a_) {
            for (int k = 0; k < K; ++k) {
              row_offset += A_quantized_i[m + k * M];
            }
          } else {
            for (int k = 0; k < K; ++k) {
              row_offset += A_quantized_i[m * K + k];
            }
          }
          row_offset *= B_qparams_[0].zero_point;
          for (int n = 0; n < N; ++n) {
            int32_t sum = 0;
            if (!trans_a_ && !trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m * K + k]) *
                    static_cast<int32_t>(B_quantized_i[k * N + n]);
              }
            } else if (!trans_a_ && trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m * K + k]) *
                    static_cast<int32_t>(B_quantized_i[k + n * K]);
              }
            } else if (trans_a_ && !trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m + k * M]) *
                    static_cast<int32_t>(B_quantized_i[k * N + n]);
              }
            } else if (trans_a_ && trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m + k * M]) *
                    static_cast<int32_t>(B_quantized_i[k + n * K]);
              }
            }

            Y_int32_[p * Y_stride + i * M * N + m * N + n] =
                sum - row_offset - column_offsets[n] + const_offset;
          } // for each output column
        } // for each output row
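        // Requantize the int32 accumulators into the output type T using the
        // precomputed multiplier targeting out_qparams_.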
        for (int j = 0; j < M * N; ++j) {
          Y_quantized[p * Y_stride + i * M * N + j] = fbgemm::Requantize<T>(
              Y_int32_[p * Y_stride + i * M * N + j],
              requantization_params_[0]);
        }
      } // for each sub-batch
    } // for each outer batch

    RunOnDeviceEpilogue_();
  } // Bq_packed_.empty()

  return true;
}
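// Registration under the DNNLowP engines: 8-bit and 16-bit variants of
// BatchMatMul plus the Int8BatchMatMul alias. The operator/engine names below
// follow the standard DNNLowP registration pattern.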
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    BatchMatMul, DNNLOWP, BatchMatMulDNNLowPOp<uint8_t>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    BatchMatMul, DNNLOWP_16, BatchMatMulDNNLowPOp<uint16_t>);

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8BatchMatMul, DNNLOWP, BatchMatMulDNNLowPOp<uint8_t>);

} // namespace caffe2