Caffe2 - C++ API
A deep learning, cross-platform ML framework
batch_matmul_dnnlowp_op.cc
#include "batch_matmul_dnnlowp_op.h"

#ifdef _OPENMP
#include <omp.h>
#endif

// #define DNNLOWP_MEASURE_TIME_BREAKDOWN
#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
#include <chrono>
#endif

namespace caffe2 {

using namespace std;
using namespace dnnlowp;

template <typename T>
BatchMatMulDNNLowPOp<T>::BatchMatMulDNNLowPOp(
    const OperatorDef& operator_def,
    Workspace* ws)
    : BaseType(operator_def, ws),
      trans_a_(this->template GetSingleArgument<int>("trans_a", 0)),
      trans_b_(this->template GetSingleArgument<int>("trans_b", 0)),
      broadcast_(this->template GetSingleArgument<int>("broadcast", 0)),
      is_B_constant_(
          this->template GetSingleArgument<bool>("constant_B", false)) {}
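// Illustrative only: a minimal sketch of how this operator might be declared
// in a NetDef, with argument names matching the GetSingleArgument calls above.
// The blob names "A", "B", and "Y" are hypothetical.
//
//   op {
//     type: "Int8BatchMatMul"
//     engine: "DNNLOWP"
//     input: "A"
//     input: "B"
//     output: "Y"
//     arg { name: "trans_b" i: 1 }
//     arg { name: "broadcast" i: 1 }
//     arg { name: "constant_B" i: 1 }
//   }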

template <typename T>
bool BatchMatMulDNNLowPOp<T>::RunOnDevice() {
  this->ParseDNNLowPOperatorArguments_();

  const auto& A = InputTensorCPU_(0);
  const auto& B = InputTensorCPU_(1);
  auto* Y = OutputTensorCPU_(0);

  auto ndims_A = A.ndim();
  auto dims_A = A.sizes().vec();
  auto ndims_B = B.ndim();
  auto dims_B = B.sizes().vec();

  auto noBroadcastErrorMsg = [](size_t dim1, size_t dim2) {
    std::stringstream ss;
    ss << "Inputs with dimensions A = ";
    ss << dim1;
    ss << " and B = ";
    ss << dim2;
    ss << " are not supported with broadcast=0. Did you forget to set the "
          "broadcast flag?";
    return ss.str();
  };

  // These should all be false if we're not broadcasting.
  bool dimMismatch = ndims_A != ndims_B;
  bool dimsLessThan1D = ndims_A < 2;
  CAFFE_ENFORCE(
      broadcast_ || (!dimMismatch && !dimsLessThan1D),
      noBroadcastErrorMsg(ndims_A, ndims_B));

  auto dimMismatchErrorString = [](size_t dimnum1,
                                   size_t dim1,
                                   size_t dimnum2,
                                   size_t dim2,
                                   bool trans_a,
                                   bool trans_b) {
    std::stringstream ss;
    ss << "Expected dimension ";
    ss << dimnum1;
    ss << " of tensor A with value ";
    ss << dim1;
    ss << " to match dimension ";
    ss << dimnum2;
    ss << " of tensor B with value ";
    ss << dim2;
    ss << ". trans_a = ";
    ss << trans_a;
    ss << " trans_b = ";
    ss << trans_b;
    return ss.str();
  };

  int num_sub_batches, num_outer_batches;
  size_t M, N, K;
  size_t A_stride = 1; // How far to increment A pointer each iteration
  size_t B_stride = 1; // How far to increment B pointer each iteration
  size_t Y_stride = 1; // How far to increment Y pointer each iteration
  if (ndims_A == 1 && ndims_B == 1) {
    // vector-vector
    CAFFE_ENFORCE_EQ(
        dims_A[0],
        dims_B[0],
        "Vector-vector product requires each of the vectors to "
        "be the same size.");
    Y->Resize(1);
    num_sub_batches = 1;
    num_outer_batches = 1;
    M = 1;
    N = 1;
    K = dims_A[0];
  } else {
    bool A_broadcasted = false, B_broadcasted = false;
    if (ndims_A == 1) {
      dims_A.insert(dims_A.begin(), 1);
      ndims_A = 2;
      A_broadcasted = true;
    }
    if (ndims_B == 1) {
      dims_B.push_back(1);
      ndims_B = 2;
      B_broadcasted = true;
    }
    // matrix-matrix with batches
    // [B1..., M, K] * [B2..., K, N] -> [B..., M, N]
    // In the event that A or B is one-dimensional, the trailing or leading
    // 1 is not added to the output tensor's size.

    // First step: partition the tensors into inner and outer blocks.
    // Ignoring the last two dimensions of A and B, ensure that one of the
    // tensors' dimensions is a suffix of the other. For example,
    // [4, x, x] is a suffix of [2, 3, 4, x, x]. In this example, the
    // dimensions of size 2 and 3 will be broadcasted, so we partition into
    // 2*3=6 individual instances of batched GEMM with A and B \in [4, x, x].
    size_t num_inner_dims = std::min(ndims_A, ndims_B);
    for (size_t i = 2; i < num_inner_dims; ++i) {
      auto first_r_itr = dims_A.rbegin();
      auto second_r_itr = dims_B.rbegin();
      CAFFE_ENFORCE_EQ(
          *(first_r_itr + i),
          *(second_r_itr + i),
          dimMismatchErrorString(
              ndims_A - i - 1,
              *(first_r_itr + i),
              ndims_B - i - 1,
              *(second_r_itr + i),
              trans_a_,
              trans_b_));
    }
    size_t num_outer_dims = std::max(ndims_A, ndims_B) - num_inner_dims;

    // Standard M, N, and K parameters respecting GEMM API and transpose
    // flags
    size_t K_dim;
    if (trans_a_) {
      M = dims_A[ndims_A - 1];
      K = dims_A[ndims_A - 2];
      K_dim = ndims_A - 2;
    } else {
      M = dims_A[ndims_A - 2];
      K = dims_A[ndims_A - 1];
      K_dim = ndims_A - 1;
    }
    if (trans_b_) {
      N = dims_B[ndims_B - 2];
      CAFFE_ENFORCE_EQ(
          K,
          dims_B[ndims_B - 1],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 1, dims_B[ndims_B - 1], trans_a_, trans_b_));
    } else {
      N = dims_B[ndims_B - 1];
      CAFFE_ENFORCE_EQ(
          K,
          dims_B[ndims_B - 2],
          dimMismatchErrorString(
              K_dim, K, ndims_B - 2, dims_B[ndims_B - 2], trans_a_, trans_b_));
    }
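    // Worked example (illustrative): with trans_a_ = 0 and trans_b_ = 1,
    // A of shape [6, 128, 64] and B of shape [6, 256, 64] give
    // M = 128 and K = 64 (taken from A's last two dimensions) and N = 256
    // (taken from B's second-to-last dimension); the CAFFE_ENFORCE_EQ above
    // then checks that B's last dimension (64) matches K.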

    // Calculate output tensor shapes [B..., (M), (N)]
    // Batch dimensions will be broadcasted out to those of the longer tensor
    // A or B. M or N is omitted if A or B, respectively, is 1-D.
    std::vector<int64_t> new_dims;
    if (ndims_A >= ndims_B) {
      new_dims.assign(dims_A.begin(), dims_A.end() - 2);
    } else {
      new_dims.assign(dims_B.begin(), dims_B.end() - 2);
    }
    if (!A_broadcasted) {
      new_dims.push_back(M);
    } else {
      new_dims.push_back(1);
    }
    if (!B_broadcasted) {
      new_dims.push_back(N);
    } else {
      new_dims.push_back(1);
    }

    // Calculate strides. Continuing our example above,
    //   [4, M, K] * [2, 3, 4, K, N] = [2, 3, 4, M, N]
    // We calculate this as follows:
    //   1) Treat the outer batch dimensions as flattened, i.e. view the B
    //      tensor here as [6, 4, K, N] and Y as [6, 4, M, N]. The same
    //      reasoning applies for the case where # dims A >= # dims B.
    //   2) Perform this operation:
    //        for i in range(6):
    //          Y[i, :, :, :] = BatchMatMul(A, B[i, :, :, :])
    A_stride = 1; // How far to increment A pointer each iteration
    B_stride = 1; // How far to increment B pointer each iteration
    Y_stride = 1; // How far to increment Y pointer each iteration
    // How many "inner batches" we have. That is, the product of sizes for
    // the slices excluding M, K, and N, for their respective matrices.
    num_sub_batches = 1;
    if (ndims_A >= ndims_B) {
      auto first_r_itr = dims_A.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        A_stride *= *(first_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(first_r_itr + i);
        }
      }
      B_stride = 0;
    } else {
      A_stride = 0;
      auto second_r_itr = dims_B.rbegin();
      auto output_r_itr = new_dims.rbegin();
      for (size_t i = 0; i < num_inner_dims; ++i) {
        B_stride *= *(second_r_itr + i);
        Y_stride *= *(output_r_itr + i);
        if (i >= 2) {
          num_sub_batches *= *(second_r_itr + i);
        }
      }
    }

    num_outer_batches = 1;
    for (size_t i = 0; i < num_outer_dims; ++i) {
      num_outer_batches *= new_dims[i];
    }

    // Mutually exclusive since otherwise we would've taken the vector-vector
    // path above
    if (A_broadcasted) {
      new_dims.erase(new_dims.end() - 2);
    } else if (B_broadcasted) {
      new_dims.erase(new_dims.end() - 1);
    }

    // Allocate output tensor
    Y->Resize(new_dims);

    // Optimize the case num_sub_batches == 1, where we can combine the
    // batched gemms into a single gemm.
    if (num_sub_batches == 1 && num_outer_batches > 1) {
      if (ndims_A > ndims_B && !trans_a_) {
        M *= num_outer_batches;
        num_outer_batches = 1;
      }
    }
  }
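  // Worked example (illustrative) for the stride logic above: with
  // A = [4, M, K] and B = [2, 3, 4, K, N], num_inner_dims = 3 and
  // num_outer_dims = 2, so A_stride = 0 (A is reused for every outer batch),
  // B_stride = 4 * K * N, Y_stride = 4 * M * N, num_sub_batches = 4, and
  // num_outer_batches = 2 * 3 = 6.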

  // Zero batch dimension indicates no elements
  if (num_sub_batches == 0 || num_outer_batches == 0) {
    if (dequantize_output_) {
      Y->template mutable_data<float>();
    } else {
      Y->template mutable_data<T>();
    }
    return true;
  }

  // Choose quantization parameters for A (input 0)
  in_qparams_[0] = GetInputTensorQuantizationParamsOf(this, 0, qfactory_.get());
  int num_batches_B = B.numel() / (K * N);
  if (!first_invocation_ && !Bq_packed_.empty() &&
      num_batches_B * N != column_offsets_.size()) {
    LOG(INFO) << "Operator with output " << this->debug_def().output(0)
              << " does not have constant B";
    is_B_constant_ = false;
    Bq_packed_.clear();
  }
  bool fast_path =
      std::is_same<T, uint8_t>::value && GetCpuId().avx2() && is_B_constant_;

  if (fast_path) {
    // Quantize B
    if (Bq_packed_.empty()) {
      int signed_min = -(1 << (qfactory_->GetWeightPrecision() - 1));
      vector<int8_t> B_quantized_temp(K * N);
      column_offsets_.resize(num_batches_B * N);
      for (int i = 0; i < num_batches_B; ++i) {
        if (this->template InputIsType<int8::Int8TensorCPU>(1)) {
          B_qparams_.push_back(TensorQuantizationParams());
          B_qparams_[i].scale =
              this->template Input<int8::Int8TensorCPU>(1).scale;
          B_qparams_[i].zero_point =
              this->template Input<int8::Int8TensorCPU>(1).zero_point +
              signed_min;

          const T* B_data = B.template data<T>() + i * B_quantized_temp.size();
          for (auto j = 0; j < B_quantized_temp.size(); ++j) {
            B_quantized_temp[j] = B_data[j] + signed_min;
          }
        } else {
          B_qparams_.emplace_back(qfactory_->ChooseQuantizationParams(
              B.template data<float>() + i * B_quantized_temp.size(),
              B_quantized_temp.size(),
              true /* weight */));

          // B_qparams_[i] is computed for unsigned type.
          // Adjust for the fact that B will actually use signed.
          B_qparams_[i].zero_point += signed_min;

          fbgemm::Quantize<int8_t>(
              B.template data<float>() + i * B_quantized_temp.size(),
              B_quantized_temp.data(),
              B_quantized_temp.size(),
              B_qparams_[i]);
        }

        Bq_packed_.emplace_back(new fbgemm::PackBMatrix<int8_t>(
            trans_b_ ? fbgemm::matrix_op_t::Transpose
                     : fbgemm::matrix_op_t::NoTranspose,
            K,
            N,
            B_quantized_temp.data(),
            trans_b_ ? K : N,
            nullptr /*pmat*/,
            1 /*groups*/));

        // Pre-compute column_offset
        for (int j = 0; j < N; ++j) {
          int32_t sum = 0;
          if (trans_b_) {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_temp[j * K + k];
            }
          } else {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_temp[k * N + j];
            }
          }
          column_offsets_[i * N + j] = sum - B_qparams_[i].zero_point * K;
        }
      } // for each input in the batch
    } // Bq_packed_.empty()
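    // Illustrative note on the column_offsets_ precompute above: for each
    // column j of batch i,
    //   column_offsets_[i * N + j] = sum_k B_q[k][j] - zero_point_B * K,
    // i.e. the column sum of the quantized B minus the zero-point correction.
    // For example, if column j of B_q is {3, 5} (K = 2) and
    // B_qparams_[i].zero_point = 1, the stored offset is (3 + 5) - 1 * 2 = 6.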

    if (!dequantize_output_) {
      GetOutputQuantizationParams_();

      for (int i = 0; i < num_batches_B; ++i) {
        float real_multiplier =
            in_qparams_[0].scale * B_qparams_[i].scale / out_qparams_.scale;
        requantization_params_.emplace_back(
            qfactory_->ChooseRequantizationMultiplier(
                real_multiplier, out_qparams_));
      }
    } else {
      if (measure_quantization_error_) {
        // to measure quantization error, run ref impl.
        Fp32Op_()->DequantizeInput();
        Fp32Op_()->Get()->RunOnDevice();
      }
    }
  } else {
    // slow path
    if (first_invocation_) {
      string reason;
      if (!is_same<T, uint8_t>::value) {
        reason = "fbgemm only supports 8-bit integers";
      } else if (!GetCpuId().avx2()) {
        reason = "fbgemm only supports AVX2";
      } else if (!is_B_constant_) {
        reason = "B is not constant";
      } else {
        assert(false);
      }
      LOG(WARNING) << "BatchMatMul with output " << this->debug_def().output(0)
                   << " falls back to slow path because " << reason;
    }
    B_qparams_.resize(1);
    requantization_params_.resize(1);

    B_qparams_[0] =
        GetInputTensorQuantizationParamsOf(this, 1, qfactory_.get());

    GetOutputQuantizationParams_();

    float real_multiplier =
        in_qparams_[0].scale * B_qparams_[0].scale / out_qparams_.scale;
    requantization_params_[0] = qfactory_->ChooseRequantizationMultiplier(
        real_multiplier, out_qparams_);
  }
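  // Illustrative note on the requantization multiplier computed above: since
  // Y_int32 accumulates products of values quantized with scale_A and scale_B,
  // the float result is approximately scale_A * scale_B * Y_int32, and mapping
  // it onto the output scale gives
  //   Y_q ~= (scale_A * scale_B / scale_Y) * Y_int32 + zero_point_Y
  //       =  real_multiplier * Y_int32 + zero_point_Y.
  // For example, scale_A = 0.02, scale_B = 0.05, and scale_Y = 0.1 yield
  // real_multiplier = 0.01.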

  first_invocation_ = false;

  vector<T> A_temp, B_temp;
  if (!Bq_packed_.empty()) {
    // fast path
    using namespace fbgemm;

    const T* A_quantized = nullptr;
    if (A.template IsType<T>() || !dequantize_output_) {
      // The input stays unquantized only when both the input and the output
      // are float.
      A_quantized = QuantizeInputIfNeeded<T>(this, 0, in_qparams_[0], A_temp);
    }

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    chrono::time_point<chrono::system_clock> t_begin, t_end;
    t_begin = chrono::system_clock::now();
#endif

    if (!dequantize_output_) {
      auto Y_data = Y->template mutable_data<T>();

      auto row_offset_len_per_thread =
          PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
      row_offsets_.resize(
          row_offset_len_per_thread * dnnlowp_get_max_threads());
      auto A_pack_buf_len_per_thread =
          PackAWithRowOffset<uint8_t>::packedBufferSize();
      A_pack_buf_.resize(A_pack_buf_len_per_thread * dnnlowp_get_max_threads());
      Y_int32_.resize(Y->numel());

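      // Illustrative summary of the fbgemm fast path in the loop below (not a
      // separate code path): PackAWithRowOffset packs the quantized A block
      // for each (outer batch, sub batch) pair and records per-row sums,
      // ReQuantizeOutput folds those row offsets, the precomputed column
      // offsets, both zero points, and the requantization multiplier into the
      // GEMM epilogue, and fbgemmPacked runs the int8 GEMM against the
      // pre-packed B, writing requantized results directly into Y.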
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
      for (int p = 0; p < num_outer_batches; ++p) {
        for (int i = 0; i < num_sub_batches; ++i) {
          int tid = dnnlowp_get_thread_num();

          PackAWithRowOffset<uint8_t> packA(
              trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
              M,
              K,
              reinterpret_cast<const uint8_t*>(A_quantized) + p * A_stride +
                  i * M * K,
              trans_a_ ? M : K,
              A_pack_buf_.data() +
                  tid * A_pack_buf_len_per_thread, // buffer for packed matrix
              1, // group
              row_offsets_.data() + tid * row_offset_len_per_thread);

          int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
          DoNothing<> doNothingObj{};
          ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
              doNothingObj,
              &requantization_params_[B_batch_idx].real_multiplier,
              out_qparams_.zero_point,
              in_qparams_[0].zero_point,
              &B_qparams_[B_batch_idx].zero_point,
              packA.getRowOffsetBuffer(),
              column_offsets_.data() + B_batch_idx * N,
              nullptr, // bias
              N); // ncols per quant group

          fbgemmPacked(
              packA,
              *Bq_packed_[B_batch_idx],
              reinterpret_cast<uint8_t*>(Y_data) + p * Y_stride + i * M * N,
              Y_int32_.data() + p * Y_stride + i * M * N,
              N,
              outputProcObj,
              0, // thread_id
              1); // num_threads
        } // for each input in batch
      }

      PropagateOutputTensorQuantizationParams(this, 0, out_qparams_);
    } else {
      // dequantize_output
      float* Y_data = Y->template mutable_data<float>();

      if (!A.template IsType<T>()) {
        // Both input and output are float
        int row_offset_len_per_thread =
            PackAWithQuantRowOffset<uint8_t>::rowOffsetBufferSize();
        row_offsets_.resize(
            row_offset_len_per_thread * dnnlowp_get_max_threads());
        int A_pack_len_per_thread =
            PackAWithQuantRowOffset<uint8_t>::packedBufferSize();
        A_pack_buf_.resize(A_pack_len_per_thread * dnnlowp_get_max_threads());

#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
        for (int p = 0; p < num_outer_batches; ++p) {
          for (int i = 0; i < num_sub_batches; ++i) {
            int tid = dnnlowp_get_thread_num();

            PackAWithQuantRowOffset<uint8_t> packA(
                trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
                M,
                K,
                A.template data<float>() + p * A_stride + i * M * K,
                trans_a_ ? M : K,
                A_pack_buf_.data() +
                    tid * A_pack_len_per_thread, // buffer for packed matrix
                in_qparams_[0].scale,
                in_qparams_[0].zero_point,
                1, // groups
                row_offsets_.data() + tid * row_offset_len_per_thread);

            int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
            DoNothing<float, float> doNothingObj{};
            ReQuantizeForFloat<false /* FUSE_RELU */> outputProcObj(
                doNothingObj,
                in_qparams_[0].scale,
                &B_qparams_[B_batch_idx].scale,
                in_qparams_[0].zero_point,
                &B_qparams_[B_batch_idx].zero_point,
                packA.getRowOffsetBuffer(),
                column_offsets_.data() + B_batch_idx * N,
                nullptr, // bias
                N); // ncols per quant group

            fbgemmPacked(
                packA,
                *Bq_packed_[B_batch_idx],
                Y_data + p * Y_stride + i * M * N,
                reinterpret_cast<int32_t*>(Y_data) + p * Y_stride + i * M * N,
                N,
                outputProcObj,
                0, // thread_id
                1); // num_threads
          } // for each input in batch
        }
      } else {
        // Input quantized and output float
        auto row_offset_len_per_thread =
            PackAWithRowOffset<uint8_t>::rowOffsetBufferSize();
        row_offsets_.resize(
            row_offset_len_per_thread * dnnlowp_get_max_threads());
        auto A_pack_buf_len_per_thread =
            PackAWithRowOffset<uint8_t>::packedBufferSize();
        A_pack_buf_.resize(
            A_pack_buf_len_per_thread * dnnlowp_get_max_threads());

#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
        for (int p = 0; p < num_outer_batches; ++p) {
          for (int i = 0; i < num_sub_batches; ++i) {
            int tid = dnnlowp_get_thread_num();

            PackAWithRowOffset<uint8_t> packA(
                trans_a_ ? matrix_op_t::Transpose : matrix_op_t::NoTranspose,
                M,
                K,
                reinterpret_cast<const uint8_t*>(A_quantized) + p * A_stride +
                    i * M * K,
                trans_a_ ? M : K,
                A_pack_buf_.data() +
                    tid * A_pack_buf_len_per_thread, // buffer for packed matrix
                1, // group
                row_offsets_.data() + tid * row_offset_len_per_thread);

            int B_batch_idx = ndims_A >= ndims_B ? i : p * num_sub_batches + i;
            DoNothing<float, float> doNothingObj{};
            ReQuantizeForFloat<false /* FUSE_RELU */> outputProcObj(
                doNothingObj,
                in_qparams_[0].scale,
                &B_qparams_[B_batch_idx].scale,
                in_qparams_[0].zero_point,
                &B_qparams_[B_batch_idx].zero_point,
                packA.getRowOffsetBuffer(),
                column_offsets_.data() + B_batch_idx * N,
                nullptr, // bias
                N); // ncols per quant group

            fbgemmPacked(
                packA,
                *Bq_packed_[B_batch_idx],
                Y_data + p * Y_stride + i * M * N,
                reinterpret_cast<int32_t*>(Y_data) + p * Y_stride + i * M * N,
                N,
                outputProcObj,
                0, // thread_id
                1); // num_threads
          } // for each input in batch
        }
      }
    } // dequantize_output

#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
    t_end = chrono::system_clock::now();
    double dt = chrono::duration<double>(t_end - t_begin).count();
    double gops =
        2. * num_outer_batches * num_sub_batches * M * N * K / dt / 1e9;
    LOG(INFO) << "batches " << num_outer_batches * num_sub_batches << " m " << M
              << " n " << N << " k " << K << " " << gops << " gops";
#endif

    MeasureQuantizationError_();
  } else {
    // slow path
    // Quantize inputs
    const T* A_quantized =
        QuantizeInputIfNeeded<T>(this, 0, in_qparams_[0], A_temp);
    const T* B_quantized =
        QuantizeInputIfNeeded<T>(this, 1, B_qparams_[0], B_temp);

    T* Y_quantized = GetQuantizedOutputData_();
    Y_int32_.resize(Y->numel());
#ifdef _OPENMP
#pragma omp parallel for collapse(2)
#endif
    for (int p = 0; p < num_outer_batches; ++p) {
      for (int i = 0; i < num_sub_batches; ++i) {
        // Y_q = (scale_A * scale_B) / scale_Y * Y_int32
        // Y_int32 = (A_q - zero_point_A * 1_A) * (B_q - zero_point_B * 1_B),
        //   where 1_A is a matrix of all 1s with the same size as A.
        // Expanding,
        // Y_int32 = A_q * B_q
        //           - zero_point_A * 1_A * B_q - zero_point_B * A_q * 1_B
        //           + zero_point_A * zero_point_B * 1_A * 1_B
        // zero_point_A * 1_A * B_q : a matrix whose (i, j) entry is
        //   zero_point_A times the sum of the jth column of B_q; computed as
        //   column_offsets in the code.
        // zero_point_B * A_q * 1_B : a matrix whose (i, j) entry is
        //   zero_point_B times the sum of the ith row of A_q; computed as
        //   row_offset in the code.
        // zero_point_A * zero_point_B * 1_A * 1_B : a matrix whose entries are
        //   all zero_point_A * zero_point_B * K (K = number of columns of A);
        //   computed as const_offset in the code.
        const T* A_quantized_i = A_quantized + p * A_stride + i * M * K;
        const T* B_quantized_i = B_quantized + p * B_stride + i * K * N;

        int32_t const_offset =
            in_qparams_[0].zero_point * B_qparams_[0].zero_point * K;
        vector<int32_t> column_offsets(N);
        for (int n = 0; n < N; ++n) {
          int32_t sum = 0;
          if (trans_b_) {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_i[k + n * K];
            }
          } else {
            for (int k = 0; k < K; ++k) {
              sum += B_quantized_i[k * N + n];
            }
          }
          column_offsets[n] = sum * in_qparams_[0].zero_point;
        }

        for (int m = 0; m < M; ++m) {
          int32_t row_offset = 0;
          if (trans_a_) {
            for (int k = 0; k < K; ++k) {
              row_offset += A_quantized_i[m + k * M];
            }
          } else {
            for (int k = 0; k < K; ++k) {
              row_offset += A_quantized_i[m * K + k];
            }
          }
          row_offset *= B_qparams_[0].zero_point;

          for (int n = 0; n < N; ++n) {
            int32_t sum = 0;
            if (!trans_a_ && !trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m * K + k]) *
                    static_cast<int32_t>(B_quantized_i[k * N + n]);
              }
            } else if (!trans_a_ && trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m * K + k]) *
                    static_cast<int32_t>(B_quantized_i[k + n * K]);
              }
            } else if (trans_a_ && !trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m + k * M]) *
                    static_cast<int32_t>(B_quantized_i[k * N + n]);
              }
            } else if (trans_a_ && trans_b_) {
              for (int k = 0; k < K; ++k) {
                sum += static_cast<int32_t>(A_quantized_i[m + k * M]) *
                    static_cast<int32_t>(B_quantized_i[k + n * K]);
              }
            }

            Y_int32_[p * Y_stride + i * M * N + m * N + n] =
                sum - row_offset - column_offsets[n] + const_offset;
          } // for each output col
        } // for each output row

        // Requantization
        for (int j = 0; j < M * N; ++j) {
          Y_quantized[p * Y_stride + i * M * N + j] = fbgemm::Requantize<T>(
              Y_int32_[p * Y_stride + i * M * N + j],
              requantization_params_[0]);
        }
      } // for each batch
    }

    RunOnDeviceEpilogue_();
  }

  return true;
}

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    BatchMatMul,
    DNNLOWP,
    BatchMatMulDNNLowPOp<uint8_t>);
REGISTER_CPU_OPERATOR_WITH_ENGINE(
    BatchMatMul,
    DNNLOWP_16,
    BatchMatMulDNNLowPOp<uint16_t>);

REGISTER_CPU_OPERATOR_WITH_ENGINE(
    Int8BatchMatMul,
    DNNLOWP,
    BatchMatMulDNNLowPOp<uint8_t>);

} // namespace caffe2
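
// Illustrative usage sketch (not part of this file): given the registrations
// above, the DNNLOWP engine can be selected from C++ roughly as follows. The
// blob names "A", "B", and "Y" and the argument values are hypothetical, and
// exact tensor-population calls vary across Caffe2 versions.
//
//   OperatorDef def;
//   def.set_type("Int8BatchMatMul");
//   def.set_engine("DNNLOWP");
//   def.add_input("A");
//   def.add_input("B");
//   def.add_output("Y");
//   auto* arg = def.add_arg();
//   arg->set_name("constant_B");
//   arg->set_i(1);
//
//   Workspace ws;
//   // ... populate the "A" and "B" blobs with CPU tensors (omitted) ...
//   auto op = CreateOperator(def, &ws);
//   op->Run();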