2 #include <ATen/NativeFunctions.h> 3 #include <ATen/Config.h> 7 namespace at {
namespace native {
10 AT_ERROR(
"bmm: ATen not compiled with MKL support");
15 #else // AT_MKL_ENABLED 17 #include <ATen/ATen.h> 18 #include <ATen/Config.h> 19 #include <ATen/Dispatch.h> 20 #include <ATen/Utils.h> 21 #include <ATen/NativeFunctions.h> 29 #include <ATen/mkl/Exceptions.h> 30 #include <ATen/mkl/Descriptors.h> 31 #include <ATen/mkl/Limits.h> 33 namespace at {
namespace native {
35 static inline void gemm_batched(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
36 const int batch_size,
const int M,
const int N,
const int K,
const float alpha,
37 const float**
A,
const float**
B,
const float beta,
float**
C) {
38 const int lda = (trans_A == CblasNoTrans) ? K : M;
39 const int ldb = (trans_B == CblasNoTrans) ? N : K;
42 cblas_sgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
43 A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
46 static inline void gemm_batched(
const CBLAS_TRANSPOSE trans_A,
const CBLAS_TRANSPOSE trans_B,
47 const int batch_size,
const int M,
const int N,
const int K,
const double alpha,
48 const double**
A,
const double**
B,
const double beta,
double**
C) {
49 const int lda = (trans_A == CblasNoTrans) ? K : M;
50 const int ldb = (trans_B == CblasNoTrans) ? N : K;
53 cblas_dgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
54 A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
57 template <
typename scalar_t>
58 static inline void baddbmm_mkl_template(
const Tensor& res,
const Tensor& mat1,
const Tensor& mat2, Scalar beta_, Scalar alpha_) {
59 auto is_transposed = [&](
const Tensor& t) {
60 return t.stride(0) == 1 && t.stride(1) == t.size(0);
62 const CBLAS_TRANSPOSE trans_A = is_transposed(mat1[0]) ? CblasTrans : CblasNoTrans;
63 const CBLAS_TRANSPOSE trans_B = is_transposed(mat2[0]) ? CblasTrans : CblasNoTrans;
65 const int batch_size = mat1.size(0);
66 const int M = mat1.size(1);
67 const int N = mat2.size(2);
68 const int K = mat1.size(2);
69 scalar_t alpha = alpha_.to<scalar_t>();
70 scalar_t beta = beta_.to<scalar_t>();
72 std::vector<const scalar_t*>
A(batch_size);
73 std::vector<const scalar_t*>
B(batch_size);
74 std::vector<scalar_t*>
C(batch_size);
75 for (int64_t batch = 0; batch < batch_size; batch++) {
76 A[batch] = mat1[batch].data<scalar_t>();
77 B[batch] = mat2[batch].data<scalar_t>();
78 C[batch] = res[batch].data<scalar_t>();
81 gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha,
A.data(),
B.data(), beta,
C.data());
86 AT_DISPATCH_FLOATING_TYPES(
self.scalar_type(),
"baddbmm__mkl", [&] {
87 baddbmm_mkl_template<scalar_t>(
self, batch1, batch2, beta, alpha);
// NOTE(review): stray text fragment — presumably part of a comment about
// MKL's Flush-To-Zero and Denormals-Are-Zero mode; confirm against the
// original file and restore the full comment.