Caffe2 - C++ API
A deep learning, cross platform ML framework
THLapack.cpp
1 #ifndef TH_GENERIC_FILE
2 #define TH_GENERIC_FILE "TH/generic/THLapack.cpp"
3 #else
4 
5 
6 TH_EXTERNC void dgels_(char *trans, int *m, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, double *work, int *lwork, int *info);
7 TH_EXTERNC void sgels_(char *trans, int *m, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, float *work, int *lwork, int *info);
8 TH_EXTERNC void dsyev_(char *jobz, char *uplo, int *n, double *a, int *lda, double *w, double *work, int *lwork, int *info);
9 TH_EXTERNC void ssyev_(char *jobz, char *uplo, int *n, float *a, int *lda, float *w, float *work, int *lwork, int *info);
10 TH_EXTERNC void dgeev_(char *jobvl, char *jobvr, int *n, double *a, int *lda, double *wr, double *wi, double* vl, int *ldvl, double *vr, int *ldvr, double *work, int *lwork, int *info);
11 TH_EXTERNC void sgeev_(char *jobvl, char *jobvr, int *n, float *a, int *lda, float *wr, float *wi, float* vl, int *ldvl, float *vr, int *ldvr, float *work, int *lwork, int *info);
12 TH_EXTERNC void dgesdd_(char *jobz, int *m, int *n, double *a, int *lda, double *s, double *u, int *ldu, double *vt, int *ldvt, double *work, int *lwork, int *iwork, int *info);
13 TH_EXTERNC void sgesdd_(char *jobz, int *m, int *n, float *a, int *lda, float *s, float *u, int *ldu, float *vt, int *ldvt, float *work, int *lwork, int *iwork, int *info);
14 TH_EXTERNC void dgetrf_(int *m, int *n, double *a, int *lda, int *ipiv, int *info);
15 TH_EXTERNC void sgetrf_(int *m, int *n, float *a, int *lda, int *ipiv, int *info);
16 TH_EXTERNC void dgetrs_(char *trans, int *n, int *nrhs, double *a, int *lda, int *ipiv, double *b, int *ldb, int *info);
17 TH_EXTERNC void sgetrs_(char *trans, int *n, int *nrhs, float *a, int *lda, int *ipiv, float *b, int *ldb, int *info);
18 TH_EXTERNC void dgetri_(int *n, double *a, int *lda, int *ipiv, double *work, int *lwork, int *info);
19 TH_EXTERNC void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int *lwork, int *info);
20 TH_EXTERNC void dpotri_(char *uplo, int *n, double *a, int *lda, int *info);
21 TH_EXTERNC void spotri_(char *uplo, int *n, float *a, int *lda, int *info);
22 TH_EXTERNC void sgeqrf_(int *m, int *n, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
23 TH_EXTERNC void dgeqrf_(int *m, int *n, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
24 TH_EXTERNC void sorgqr_(int *m, int *n, int *k, float *a, int *lda, float *tau, float *work, int *lwork, int *info);
25 TH_EXTERNC void dorgqr_(int *m, int *n, int *k, double *a, int *lda, double *tau, double *work, int *lwork, int *info);
26 TH_EXTERNC void sormqr_(char *side, char *trans, int *m, int *n, int *k, float *a, int *lda, float *tau, float *c, int *ldc, float *work, int *lwork, int *info);
27 TH_EXTERNC void dormqr_(char *side, char *trans, int *m, int *n, int *k, double *a, int *lda, double *tau, double *c, int *ldc, double *work, int *lwork, int *info);
28 TH_EXTERNC void spstrf_(char *uplo, int *n, float *a, int *lda, int *piv, int *rank, float *tol, float *work, int *info);
29 TH_EXTERNC void dpstrf_(char *uplo, int *n, double *a, int *lda, int *piv, int *rank, double *tol, double *work, int *info);
30 
31 
32 /* Solve overdetermined or underdetermined real linear systems involving an
33 M-by-N matrix A, or its transpose, using a QR or LQ factorization of A */
34 void THLapack_(gels)(char trans, int m, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, scalar_t *work, int lwork, int *info)
35 {
36 #ifdef USE_LAPACK
37 #if defined(TH_REAL_IS_DOUBLE)
38  dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
39 #else
40  sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work, &lwork, info);
41 #endif
42 #else
43  THError("gels : Lapack library not found in compile time\n");
44 #endif
45 }
46 
47 /* Compute all eigenvalues and, optionally, eigenvectors of a real symmetric
48 matrix A */
49 void THLapack_(syev)(char jobz, char uplo, int n, scalar_t *a, int lda, scalar_t *w, scalar_t *work, int lwork, int *info)
50 {
51 #ifdef USE_LAPACK
52 #if defined(TH_REAL_IS_DOUBLE)
53  dsyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
54 #else
55  ssyev_(&jobz, &uplo, &n, a, &lda, w, work, &lwork, info);
56 #endif
57 #else
58  THError("syev : Lapack library not found in compile time\n");
59 #endif
60 }
61 
62 /* Compute for an N-by-N real nonsymmetric matrix A, the eigenvalues and,
63 optionally, the left and/or right eigenvectors */
64 void THLapack_(geev)(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info)
65 {
66 #ifdef USE_LAPACK
67 #if defined(TH_REAL_IS_DOUBLE)
68  dgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
69 #else
70  sgeev_(&jobvl, &jobvr, &n, a, &lda, wr, wi, vl, &ldvl, vr, &ldvr, work, &lwork, info);
71 #endif
72 #else
73  THError("geev : Lapack library not found in compile time\n");
74 #endif
75 }
76 
77 /* Compute the singular value decomposition (SVD) of a real M-by-N matrix A,
78 optionally computing the left and/or right singular vectors */
79 void THLapack_(gesdd)(char jobz, int m, int n, scalar_t *a, int lda, scalar_t *s, scalar_t *u, int ldu, scalar_t *vt, int ldvt, scalar_t *work, int lwork, int *iwork, int *info)
80 {
81 #ifdef USE_LAPACK
82 #if defined(TH_REAL_IS_DOUBLE)
83  dgesdd_( &jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
84 #else
85  sgesdd_( &jobz, &m, &n, a, &lda, s, u, &ldu, vt, &ldvt, work, &lwork, iwork, info);
86 #endif
87 #else
88  THError("gesdd : Lapack library not found in compile time\n");
89 #endif
90 }
91 
92 /* LU decomposition */
93 void THLapack_(getrf)(int m, int n, scalar_t *a, int lda, int *ipiv, int *info)
94 {
95 #ifdef USE_LAPACK
96 #if defined(TH_REAL_IS_DOUBLE)
97  dgetrf_(&m, &n, a, &lda, ipiv, info);
98 #else
99  sgetrf_(&m, &n, a, &lda, ipiv, info);
100 #endif
101 #else
102  THError("getrf : Lapack library not found in compile time\n");
103 #endif
104 }
105 
106 void THLapack_(getrs)(char trans, int n, int nrhs, scalar_t *a, int lda, int *ipiv, scalar_t *b, int ldb, int *info)
107 {
108 #ifdef USE_LAPACK
109 #if defined(TH_REAL_IS_DOUBLE)
110  dgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
111 #else
112  sgetrs_(&trans, &n, &nrhs, a, &lda, ipiv, b, &ldb, info);
113 #endif
114 #else
115  THError("getrs : Lapack library not found in compile time\n");
116 #endif
117 }
118 
119 /* Matrix Inverse */
120 void THLapack_(getri)(int n, scalar_t *a, int lda, int *ipiv, scalar_t *work, int lwork, int* info)
121 {
122 #ifdef USE_LAPACK
123 #if defined(TH_REAL_IS_DOUBLE)
124  dgetri_(&n, a, &lda, ipiv, work, &lwork, info);
125 #else
126  sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
127 #endif
128 #else
129  THError("getri : Lapack library not found in compile time\n");
130 #endif
131 }
132 
133 /* Cholesky factorization based Matrix Inverse */
134 void THLapack_(potri)(char uplo, int n, scalar_t *a, int lda, int *info)
135 {
136 #ifdef USE_LAPACK
137 #if defined(TH_REAL_IS_DOUBLE)
138  dpotri_(&uplo, &n, a, &lda, info);
139 #else
140  spotri_(&uplo, &n, a, &lda, info);
141 #endif
142 #else
143  THError("potri: Lapack library not found in compile time\n");
144 #endif
145 }
146 
147 /* Cholesky factorization with complete pivoting */
148 void THLapack_(pstrf)(char uplo, int n, scalar_t *a, int lda, int *piv, int *rank, scalar_t tol, scalar_t *work, int *info)
149 {
150 #ifdef USE_LAPACK
151 #if defined(TH_REAL_IS_DOUBLE)
152  dpstrf_(&uplo, &n, a, &lda, piv, rank, &tol, work, info);
153 #else
154  spstrf_(&uplo, &n, a, &lda, piv, rank, &tol, work, info);
155 #endif
156 #else
157  THError("pstrf: Lapack library not found at compile time\n");
158 #endif
159 }
160 
161 /* QR decomposition */
162 void THLapack_(geqrf)(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info)
163 {
164 #ifdef USE_LAPACK
165 #if defined(TH_REAL_IS_DOUBLE)
166  dgeqrf_(&m, &n, a, &lda, tau, work, &lwork, info);
167 #else
168  sgeqrf_(&m, &n, a, &lda, tau, work, &lwork, info);
169 #endif
170 #else
171  THError("geqrf: Lapack library not found in compile time\n");
172 #endif
173 }
174 
175 /* Build Q from output of geqrf */
176 void THLapack_(orgqr)(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info)
177 {
178 #ifdef USE_LAPACK
179 #if defined(TH_REAL_IS_DOUBLE)
180  dorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info);
181 #else
182  sorgqr_(&m, &n, &k, a, &lda, tau, work, &lwork, info);
183 #endif
184 #else
185  THError("orgqr: Lapack library not found in compile time\n");
186 #endif
187 }
188 
189 /* Multiply Q with a matrix using the output of geqrf */
190 void THLapack_(ormqr)(char side, char trans, int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *c, int ldc, scalar_t *work, int lwork, int *info)
191 {
192 #ifdef USE_LAPACK
193 #if defined(TH_REAL_IS_DOUBLE)
194  dormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
195 #else
196  sormqr_(&side, &trans, &m, &n, &k, a, &lda, tau, c, &ldc, work, &lwork, info);
197 #endif
198 #else
199  THError("ormqr: Lapack library not found in compile time\n");
200 #endif
201 }
202 
203 
204 #endif