Caffe2 - C++ API
A deep learning, cross-platform ML framework
LinearAlgebra.cpp
1 #include <ATen/ATen.h>
2 #include <ATen/NativeFunctions.h>
3 #include <ATen/Config.h>
4 
5 #if !AT_MKL_ENABLED()
6 
7 namespace at { namespace native {
8 
9 Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
10  AT_ERROR("bmm: ATen not compiled with MKL support");
11 }
12 
13 }}
14 
15 #else // AT_MKL_ENABLED
16 
17 #include <ATen/ATen.h>
18 #include <ATen/Config.h>
19 #include <ATen/Dispatch.h>
20 #include <ATen/Utils.h>
21 #include <ATen/NativeFunctions.h>
22 
23 #include <algorithm>
24 #include <vector>
25 #include <numeric>
26 #include <cmath>
27 
28 #include <mkl.h>
29 #include <ATen/mkl/Exceptions.h>
30 #include <ATen/mkl/Descriptors.h>
31 #include <ATen/mkl/Limits.h>
32 
33 namespace at { namespace native {
34 
35 static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
36  const int batch_size, const int M, const int N, const int K, const float alpha,
37  const float** A, const float** B, const float beta, float** C) {
38  const int lda = (trans_A == CblasNoTrans) ? K : M;
39  const int ldb = (trans_B == CblasNoTrans) ? N : K;
40  const int ldc = N;
41 
42  cblas_sgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
43  A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
44 }
45 
46 static inline void gemm_batched(const CBLAS_TRANSPOSE trans_A, const CBLAS_TRANSPOSE trans_B,
47  const int batch_size, const int M, const int N, const int K, const double alpha,
48  const double** A, const double** B, const double beta, double** C) {
49  const int lda = (trans_A == CblasNoTrans) ? K : M;
50  const int ldb = (trans_B == CblasNoTrans) ? N : K;
51  const int ldc = N;
52 
53  cblas_dgemm_batch(CblasRowMajor, &trans_A, &trans_B, &M, &N, &K, &alpha,
54  A, &lda, B, &ldb, &beta, C, &ldc, 1, &batch_size);
55 }
56 
57 template <typename scalar_t>
58 static inline void baddbmm_mkl_template(const Tensor& res, const Tensor& mat1, const Tensor& mat2, Scalar beta_, Scalar alpha_) {
59  auto is_transposed = [&](const Tensor& t) {
60  return t.stride(0) == 1 && t.stride(1) == t.size(0);
61  };
62  const CBLAS_TRANSPOSE trans_A = is_transposed(mat1[0]) ? CblasTrans : CblasNoTrans;
63  const CBLAS_TRANSPOSE trans_B = is_transposed(mat2[0]) ? CblasTrans : CblasNoTrans;
64 
65  const int batch_size = mat1.size(0);
66  const int M = mat1.size(1);
67  const int N = mat2.size(2);
68  const int K = mat1.size(2);
69  scalar_t alpha = alpha_.to<scalar_t>();
70  scalar_t beta = beta_.to<scalar_t>();
71 
72  std::vector<const scalar_t*> A(batch_size);
73  std::vector<const scalar_t*> B(batch_size);
74  std::vector<scalar_t*> C(batch_size);
75  for (int64_t batch = 0; batch < batch_size; batch++) {
76  A[batch] = mat1[batch].data<scalar_t>();
77  B[batch] = mat2[batch].data<scalar_t>();
78  C[batch] = res[batch].data<scalar_t>();
79  }
80 
81  gemm_batched(trans_A, trans_B, batch_size, M, N, K, alpha, A.data(), B.data(), beta, C.data());
82 }
83 
// In-place batched matmul-with-add backed by MKL:
// self = beta * self + alpha * (batch1 @ batch2), dispatched over float/double.
// Shape/dtype validation happens in the caller, not here.
Tensor& _baddbmm_mkl_(Tensor& self, const Tensor& batch1, const Tensor& batch2, Scalar beta, Scalar alpha) {
  // checks are done in native/LinearAlgebra.cpp
  AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "baddbmm__mkl", [&] {
    baddbmm_mkl_template<scalar_t>(self, batch1, batch2, beta, alpha);
  });

  return self;
}
92 
93 }} // namespace at::native
94 
95 #endif
Definition: any.cpp:108
Definition: static.cpp:52
Definition: static.cpp:64
Definition: static.cpp:58
Flush-To-Zero and Denormals-Are-Zero mode.