1 #ifndef TH_GENERIC_FILE 2 #define TH_GENERIC_FILE "TH/generic/THBlas.cpp" 12 TH_EXTERNC
void dswap_(
int *n,
double *x,
int *incx,
double *y,
int *incy);
13 TH_EXTERNC
void sswap_(
int *n,
float *x,
int *incx,
float *y,
int *incy);
14 TH_EXTERNC
void dscal_(
int *n,
double *a,
double *x,
int *incx);
15 TH_EXTERNC
void sscal_(
int *n,
float *a,
float *x,
int *incx);
16 TH_EXTERNC
void dcopy_(
int *n,
double *x,
int *incx,
double *y,
int *incy);
17 TH_EXTERNC
void scopy_(
int *n,
float *x,
int *incx,
float *y,
int *incy);
18 TH_EXTERNC
void daxpy_(
int *n,
double *a,
double *x,
int *incx,
double *y,
int *incy);
19 TH_EXTERNC
void saxpy_(
int *n,
float *a,
float *x,
int *incx,
float *y,
int *incy);
20 TH_EXTERNC
double ddot_(
int *n,
double *x,
int *incx,
double *y,
int *incy);
/*
 * Single-precision dot product entry point sdot_.
 *
 * When BLAS_USE_CBLAS_DOT is defined, sdot_ is provided here as a
 * static inline wrapper that dereferences its pointer arguments and forwards
 * to the CBLAS by-value function cblas_sdot. The TH_EXTERNC sdot_ declaration
 * further down is presumably the alternative used when BLAS_USE_CBLAS_DOT is
 * not defined.
 *
 * NOTE(review): the wrapper's braces and the #else/#endif lines of these
 * conditionals are not visible here -- confirm the conditional structure.
 * NOTE(review): ffloat is not defined in the visible code; presumably a macro
 * choosing float or double to match the linked BLAS's sdot return ABI -- confirm.
 */
21 #ifdef BLAS_USE_CBLAS_DOT 22 TH_EXTERNC
float cblas_sdot(
const int n,
const float *x,
const int incx,
const float *y,
const int incy);
/* THBlas_C_sdot_ guards the inline definition so it is emitted only once,
 * presumably because this generic file is included once per scalar type. */
23 #ifndef THBlas_C_sdot_ 24 #define THBlas_C_sdot_ 25 static inline ffloat sdot_(
const int *n,
const float *x,
const int *incx,
const float *y,
const int *incy)
27 return cblas_sdot(*n, x, *incx, y, *incy);
/* External Fortran-style sdot_ (all arguments by pointer). */
31 TH_EXTERNC ffloat sdot_(
int *n,
float *x,
int *incx,
float *y,
int *incy);
33 TH_EXTERNC
void dgemv_(
char *trans,
int *m,
int *n,
double *alpha,
double *a,
int *lda,
double *x,
int *incx,
double *beta,
double *y,
int *incy);
34 TH_EXTERNC
void sgemv_(
char *trans,
int *m,
int *n,
float *alpha,
float *a,
int *lda,
float *x,
int *incx,
float *beta,
float *y,
int *incy);
35 TH_EXTERNC
void dger_(
int *m,
int *n,
double *alpha,
double *x,
int *incx,
double *y,
int *incy,
double *a,
int *lda);
36 TH_EXTERNC
void sger_(
int *m,
int *n,
float *alpha,
float *x,
int *incx,
float *y,
int *incy,
float *a,
int *lda);
37 TH_EXTERNC
void dgemm_(
char *transa,
char *transb,
int *m,
int *n,
int *k,
double *alpha,
double *a,
int *lda,
double *b,
int *ldb,
double *beta,
double *c,
int *ldc);
38 TH_EXTERNC
void sgemm_(
char *transa,
char *transb,
int *m,
int *n,
int *k,
float *alpha,
float *a,
int *lda,
float *b,
int *ldb,
float *beta,
float *c,
int *ldc);
/*
 * THBlas_(swap): exchange the n-element strided vectors x and y
 * (element i of x is x[i*incx], element i of y is y[i*incy]).
 *
 * In double/float builds with USE_BLAS, when n, incx and incy are all
 * <= INT_MAX, the arguments are narrowed to int and passed by pointer to
 * dswap_ (double) or sswap_ (float). The index loop further down swaps
 * element by element through the temporary z. It is presumably the fallback
 * taken when the BLAS branch does not apply (the early return is not
 * visible here).
 *
 * NOTE(review): only the upper bound of n/incx/incy is checked before the
 * (int) casts. A negative stride below INT_MIN would not fit in an int
 * (implementation-defined conversion) -- confirm callers cannot pass one.
 */
42 void THBlas_(swap)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
50 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 51 if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
54 int i_incx = (int)incx;
55 int i_incy = (int)incy;
57 #if defined(TH_REAL_IS_DOUBLE) 58 dswap_(&i_n, x, &i_incx, y, &i_incy);
60 sswap_(&i_n, x, &i_incx, y, &i_incy);
/* Portable path: swap each pair through the temporary z. */
67 for(i = 0; i < n; i++)
69 scalar_t z = x[i*incx];
70 x[i*incx] = y[i*incy];
/*
 * THBlas_(scal): scale the n-element strided vector x (stride incx) by the
 * scalar a, in the manner of BLAS xSCAL.
 *
 * In double/float builds with USE_BLAS, when n and incx are <= INT_MAX, the
 * call is forwarded to dscal_/sscal_ with int-narrowed sizes; a is passed
 * by address. Otherwise the index loop below presumably does the scaling
 * (its body is not visible here).
 *
 * NOTE(review): incx is only checked against INT_MAX before the (int) cast.
 */
76 void THBlas_(scal)(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
81 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 82 if( (n <= INT_MAX) && (incx <= INT_MAX) )
85 int i_incx = (int)incx;
87 #if defined(TH_REAL_IS_DOUBLE) 88 dscal_(&i_n, &a, x, &i_incx);
90 sscal_(&i_n, &a, x, &i_incx);
97 for(i = 0; i < n; i++) {
/*
 * THBlas_(copy): copy the n-element strided vector x into y,
 * i.e. y[i*incy] = x[i*incx] for i in [0, n).
 *
 * Uses dcopy_/scopy_ in double/float builds with USE_BLAS when n, incx and
 * incy are all <= INT_MAX. The loop below is the portable element-wise
 * version.
 *
 * NOTE(review): strides are only bounded above before the (int) casts.
 */
107 void THBlas_(copy)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
115 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 116 if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
119 int i_incx = (int)incx;
120 int i_incy = (int)incy;
122 #if defined(TH_REAL_IS_DOUBLE) 123 dcopy_(&i_n, x, &i_incx, y, &i_incy);
125 scopy_(&i_n, x, &i_incx, y, &i_incy);
132 for(i = 0; i < n; i++)
133 y[i*incy] = x[i*incx];
/*
 * THBlas_(axpy): y <- a*x + y on n-element strided vectors,
 * i.e. y[i*incy] += a*x[i*incx] for i in [0, n).
 *
 * Uses daxpy_/saxpy_ (a passed by address) in double/float builds with
 * USE_BLAS when n, incx and incy are all <= INT_MAX. The loop below is the
 * portable element-wise version.
 */
137 void THBlas_(axpy)(int64_t n, scalar_t a, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
145 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 146 if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
149 int i_incx = (int)incx;
150 int i_incy = (int)incy;
152 #if defined(TH_REAL_IS_DOUBLE) 153 daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
155 saxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
162 for(i = 0; i < n; i++)
163 y[i*incy] += a*x[i*incx];
/*
 * THBlas_(dot): return the dot product sum_i x[i*incx]*y[i*incy] of two
 * n-element strided vectors.
 *
 * In double/float builds with USE_BLAS, when n, incx and incy are all
 * <= INT_MAX, the result of ddot_/sdot_ is cast to scalar_t and returned.
 * sdot_'s return type is ffloat, which may be wider than float. Otherwise
 * the loop below accumulates into sum (presumably initialised to 0; the
 * declaration is not visible here).
 */
167 scalar_t THBlas_(dot)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
175 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 176 if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
179 int i_incx = (int)incx;
180 int i_incy = (int)incy;
182 #if defined(TH_REAL_IS_DOUBLE) 183 return (scalar_t) ddot_(&i_n, x, &i_incx, y, &i_incy);
185 return (scalar_t) sdot_(&i_n, x, &i_incx, y, &i_incy);
192 for(i = 0; i < n; i++)
193 sum += x[i*incx]*y[i*incy];
/*
 * Matrix-vector product body (forwards to dgemv_/sgemv_). The function
 * signature is not visible here; presumably THBlas_(gemv).
 *
 * BLAS path: taken in double/float builds with USE_BLAS when m, n and lda are
 * <= INT_MAX and incx/incy are strictly positive and <= INT_MAX. lda must be
 * at least max(1, m) (THArgCheck argument 6).
 *
 * NOTE(review): the THArgCheck message formats m and lda with %d. If these
 * parameters are int64_t, like the sizes of the other THBlas_ routines, that
 * is a printf-format mismatch -- consider "%" PRId64 or an explicit cast.
 */
214 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 215 if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) &&
216 (incx > 0) && (incx <= INT_MAX) &&
217 (incy > 0) && (incy <= INT_MAX) )
219 THArgCheck(lda >= THMax(1, m), 6,
220 "lda should be at least max(1, m=%d), but have %d", m, lda);
223 int i_lda = (int)lda;
224 int i_incx = (int)incx;
225 int i_incy = (int)incy;
227 #if defined(TH_REAL_IS_DOUBLE) 228 dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
230 sgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
/*
 * Portable path, transposed ('T' or 't'): for each i < n, sum is the dot
 * product of x with the lda-strided slice starting at a+lda*i. y[i*incy] is
 * then either overwritten with alpha*sum or set to beta*y + alpha*sum. The
 * condition selecting between the two is not visible here (presumably
 * beta == 0).
 */
238 if( (trans ==
'T') || (trans ==
't') )
240 for(i = 0; i < n; i++)
243 scalar_t *row_ = a+lda*i;
244 for(j = 0; j < m; j++)
245 sum += x[j*incx]*row_[j];
247 y[i*incy] = alpha*sum;
249 y[i*incy] = beta*y[i*incy] + alpha*sum;
/*
 * Portable path, not transposed: scale the m-element y by beta, then add
 * alpha*x[j] times column j (a+lda*j) into y for each j < n.
 */
255 THBlas_(scal)(m, beta, y, incy);
257 for(j = 0; j < n; j++)
259 scalar_t *column_ = a+lda*j;
260 scalar_t z = alpha*x[j*incx];
261 for(i = 0; i < m; i++)
262 y[i*incy] += z*column_[i];
/*
 * Rank-1 update body (forwards to dger_/sger_). The function signature is
 * not visible here; presumably THBlas_(ger). It computes
 * a[i + j*lda] += alpha * x[i*incx] * y[j*incy] for i < m, j < n.
 *
 * BLAS path: taken in double/float builds with USE_BLAS when m, n and lda are
 * <= INT_MAX and incx/incy are strictly positive and <= INT_MAX. lda must be
 * at least max(1, m) (THArgCheck argument 9).
 *
 * NOTE(review): the THArgCheck message formats m and lda with %d; if they are
 * int64_t this is a printf-format mismatch -- consider "%" PRId64 or a cast.
 */
282 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 283 if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) &&
284 (incx > 0) && (incx <= INT_MAX) &&
285 (incy > 0) && (incy <= INT_MAX) )
287 THArgCheck(lda >= THMax(1, m), 9,
288 "lda should be at least max(1, m=%d), but have %d", m, lda);
291 int i_lda = (int)lda;
292 int i_incx = (int)incx;
293 int i_incy = (int)incy;
295 #if defined(TH_REAL_IS_DOUBLE) 296 dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
298 sger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
/* Portable path: add alpha*y[j] times x into each column j of a. */
305 for(j = 0; j < n; j++)
307 scalar_t *column_ = a+j*lda;
308 scalar_t z = alpha*y[j*incy];
309 for(i = 0; i < m; i++)
310 column_[i] += z*x[i*incx] ;
/*
 * General matrix-matrix product body (forwards to dgemm_/sgemm_). The
 * function signature is not visible here (presumably THBlas_(gemm)), and the
 * body continues past the end of this view. It computes
 * c <- alpha*op(a)*op(b) + beta*c, where op transposes when the flag is
 * 't'/'T'. Matrices are indexed with column strides lda/ldb/ldc.
 *
 * transa_/transb_ are 1 when the corresponding flag is 't' or 'T'.
 */
330 int transa_ = ((transa ==
't') || (transa ==
'T'));
331 int transb_ = ((transb ==
't') || (transb ==
'T'));
/*
 * BLAS path: taken in double/float builds with USE_BLAS when m, n, k, lda, ldb
 * and ldc are all <= INT_MAX. It validates lda >= max(1, transa_ ? k : m),
 * ldb >= max(1, transb_ ? n : k) and ldc >= max(1, m) (THArgCheck arguments
 * 8, 10, 13). No such checks are visible on the portable path below.
 *
 * NOTE(review): the THArgCheck messages format their values with %d; if
 * m/n/k/lda/ldb/ldc are int64_t this is a printf-format mismatch -- consider
 * "%" PRId64 or explicit casts.
 */
358 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT)) 359 if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
360 (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
362 THArgCheck(lda >= THMax(1, (transa_ ? k : m)), 8,
363 "lda should be at least max(1, %d), but have %d", (transa_ ? k : m), lda);
364 THArgCheck(ldb >= THMax(1, (transb_ ? n : k)), 10,
365 "ldb should be at least max(1, %d), but have %d", (transb_ ? n : k), ldb);
366 THArgCheck(ldc >= THMax(1, m), 13,
367 "ldc should be at least max(1, m=%d), but have %d", m, ldc);
371 int i_lda = (int)lda;
372 int i_ldb = (int)ldb;
373 int i_ldc = (int)ldc;
375 #if defined(TH_REAL_IS_DOUBLE) 376 dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
378 sgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
/*
 * Portable path, no transposes. First c is prepared: the loop over c whose
 * body is not visible presumably handles beta == 0, and the next loop scales
 * c by beta. Then, for each (l, j), alpha*b[l + j*ldb] times column l of a is
 * added into column j of c. The i loop is unrolled by 4 over i_m blocks
 * (i_m presumably m/4), and a remainder loop covers the leftover rows.
 */
384 if(!transa_ && !transb_)
387 for (int64_t j = 0; j < n; j++) {
388 for (int64_t i = 0; i < m; i++) {
394 for (int64_t j = 0; j < n; j++) {
395 for (int64_t i = 0; i < m; i++) {
396 c[j * ldc + i] *= beta;
400 for (int64_t l = 0; l < k; l++) {
401 for (int64_t j = 0; j < n; j++) {
402 scalar_t val = b[l + j * ldb] * alpha;
404 for (int64_t i_i = 0; i_i < i_m; i_i++) {
405 c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
406 c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
407 c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
408 c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
412 c[j * ldc + i] += a[i + l * lda] * val;
/*
 * Portable path, a transposed only. Each c[j*ldc+i] comes from a dot-product
 * sum over l < k and is either set to alpha*sum or to
 * beta*c + alpha*sum. The selecting condition is not visible here
 * (presumably beta == 0).
 */
416 else if(transa_ && !transb_)
420 for(i = 0; i < m; i++)
423 for(j = 0; j < n; j++)
426 for(l = 0; l < k; l++)
430 c[j*ldc+i] = alpha*sum;
432 c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
/*
 * Portable path, b transposed only. The structure is the same as the
 * no-transpose case, except the scalar taken from b is b[j + l*ldb].
 */
437 else if(!transa_ && transb_)
440 for (int64_t j = 0; j < n; j++) {
441 for (int64_t i = 0; i < m; i++) {
447 for (int64_t j = 0; j < n; j++) {
448 for (int64_t i = 0; i < m; i++) {
449 c[j * ldc + i] *= beta;
453 for (int64_t l = 0; l < k; l++) {
454 for (int64_t j = 0; j < n; j++) {
455 scalar_t val = b[j + l * ldb] * alpha;
457 for (int64_t i_i = 0; i_i < i_m; i_i++) {
458 c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
459 c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
460 c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
461 c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
465 c[j * ldc + i] += a[i + l * lda] * val;
/*
 * Presumably the remaining case, both a and b transposed (the branch keyword
 * is not visible). c is prepared and scaled by beta. Then
 * a[i*lda + l] * b[l*ldb + j] * alpha is accumulated into c[j*ldc + i]. The l
 * loop is unrolled by 4 over l_k blocks (l_k presumably k/4), followed by a
 * remainder loop.
 */
471 for (int64_t i = 0; i < m; i++) {
472 for (int64_t j = 0; j < n; j++) {
476 c[j * ldc + i] *= beta;
479 for (int64_t i = 0; i < m; i++) {
480 for (int64_t j = 0; j < n; j++) {
482 for (int64_t l_l = 0; l_l < l_k; l_l++) {
483 c[j * ldc + i] += a[i * lda + l_l * 4 + 0]
484 * b[(l_l * 4 + 0) * ldb + j] * alpha;
485 c[j * ldc + i] += a[i * lda + l_l * 4 + 1]
486 * b[(l_l * 4 + 1) * ldb + j] * alpha;
487 c[j * ldc + i] += a[i * lda + l_l * 4 + 2]
488 * b[(l_l * 4 + 2) * ldb + j] * alpha;
489 c[j * ldc + i] += a[i * lda + l_l * 4 + 3]
490 * b[(l_l * 4 + 3) * ldb + j] * alpha;
494 c[j * ldc + i] += a[i * lda + l] * b[l * ldb + j] * alpha;