Caffe2 - C++ API
A deep learning, cross platform ML framework
THBlas.cpp
1 #ifndef TH_GENERIC_FILE
2 #define TH_GENERIC_FILE "TH/generic/THBlas.cpp"
3 #else
4 
5 
6 #ifdef BLAS_F2C
7 # define ffloat double
8 #else
9 # define ffloat float
10 #endif
11 
12 TH_EXTERNC void dswap_(int *n, double *x, int *incx, double *y, int *incy);
13 TH_EXTERNC void sswap_(int *n, float *x, int *incx, float *y, int *incy);
14 TH_EXTERNC void dscal_(int *n, double *a, double *x, int *incx);
15 TH_EXTERNC void sscal_(int *n, float *a, float *x, int *incx);
16 TH_EXTERNC void dcopy_(int *n, double *x, int *incx, double *y, int *incy);
17 TH_EXTERNC void scopy_(int *n, float *x, int *incx, float *y, int *incy);
18 TH_EXTERNC void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy);
19 TH_EXTERNC void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy);
20 TH_EXTERNC double ddot_(int *n, double *x, int *incx, double *y, int *incy);
21 #ifdef BLAS_USE_CBLAS_DOT
22 TH_EXTERNC float cblas_sdot(const int n, const float *x, const int incx, const float *y, const int incy);
23 #ifndef THBlas_C_sdot_
24 #define THBlas_C_sdot_
25 static inline ffloat sdot_(const int *n, const float *x, const int *incx, const float *y, const int *incy)
26 {
27  return cblas_sdot(*n, x, *incx, y, *incy);
28 }
29 #endif
30 #else
31 TH_EXTERNC ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
32 #endif
33 TH_EXTERNC void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
34 TH_EXTERNC void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);
35 TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, double *y, int *incy, double *a, int *lda);
36 TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda);
37 TH_EXTERNC void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc);
38 TH_EXTERNC void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc);
39 
40 
41 
42 void THBlas_(swap)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
43 {
44  if(n == 1)
45  {
46  incx = 1;
47  incy = 1;
48  }
49 
50 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
51  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
52  {
53  int i_n = (int)n;
54  int i_incx = (int)incx;
55  int i_incy = (int)incy;
56 
57 #if defined(TH_REAL_IS_DOUBLE)
58  dswap_(&i_n, x, &i_incx, y, &i_incy);
59 #else
60  sswap_(&i_n, x, &i_incx, y, &i_incy);
61 #endif
62  return;
63  }
64 #endif
65  {
66  int64_t i;
67  for(i = 0; i < n; i++)
68  {
69  scalar_t z = x[i*incx];
70  x[i*incx] = y[i*incy];
71  y[i*incy] = z;
72  }
73  }
74 }
75 
76 void THBlas_(scal)(int64_t n, scalar_t a, scalar_t *x, int64_t incx)
77 {
78  if(n == 1)
79  incx = 1;
80 
81 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
82  if( (n <= INT_MAX) && (incx <= INT_MAX) )
83  {
84  int i_n = (int)n;
85  int i_incx = (int)incx;
86 
87 #if defined(TH_REAL_IS_DOUBLE)
88  dscal_(&i_n, &a, x, &i_incx);
89 #else
90  sscal_(&i_n, &a, x, &i_incx);
91 #endif
92  return;
93  }
94 #endif
95  {
96  int64_t i;
97  for(i = 0; i < n; i++) {
98  if (a == 0) {
99  x[i*incx] = 0;
100  } else {
101  x[i*incx] *= a;
102  }
103  }
104  }
105 }
106 
107 void THBlas_(copy)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
108 {
109  if(n == 1)
110  {
111  incx = 1;
112  incy = 1;
113  }
114 
115 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
116  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
117  {
118  int i_n = (int)n;
119  int i_incx = (int)incx;
120  int i_incy = (int)incy;
121 
122 #if defined(TH_REAL_IS_DOUBLE)
123  dcopy_(&i_n, x, &i_incx, y, &i_incy);
124 #else
125  scopy_(&i_n, x, &i_incx, y, &i_incy);
126 #endif
127  return;
128  }
129 #endif
130  {
131  int64_t i;
132  for(i = 0; i < n; i++)
133  y[i*incy] = x[i*incx];
134  }
135 }
136 
137 void THBlas_(axpy)(int64_t n, scalar_t a, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
138 {
139  if(n == 1)
140  {
141  incx = 1;
142  incy = 1;
143  }
144 
145 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
146  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
147  {
148  int i_n = (int)n;
149  int i_incx = (int)incx;
150  int i_incy = (int)incy;
151 
152 #if defined(TH_REAL_IS_DOUBLE)
153  daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
154 #else
155  saxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
156 #endif
157  return;
158  }
159 #endif
160  {
161  int64_t i;
162  for(i = 0; i < n; i++)
163  y[i*incy] += a*x[i*incx];
164  }
165 }
166 
167 scalar_t THBlas_(dot)(int64_t n, scalar_t *x, int64_t incx, scalar_t *y, int64_t incy)
168 {
169  if(n == 1)
170  {
171  incx = 1;
172  incy = 1;
173  }
174 
175 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
176  if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
177  {
178  int i_n = (int)n;
179  int i_incx = (int)incx;
180  int i_incy = (int)incy;
181 
182 #if defined(TH_REAL_IS_DOUBLE)
183  return (scalar_t) ddot_(&i_n, x, &i_incx, y, &i_incy);
184 #else
185  return (scalar_t) sdot_(&i_n, x, &i_incx, y, &i_incy);
186 #endif
187  }
188 #endif
189  {
190  int64_t i;
191  scalar_t sum = 0;
192  for(i = 0; i < n; i++)
193  sum += x[i*incx]*y[i*incy];
194  return sum;
195  }
196 }
197 
198 void THBlas_(gemv)(
199  char trans,
200  int64_t m,
201  int64_t n,
202  scalar_t alpha,
203  scalar_t *a,
204  int64_t lda,
205  scalar_t *x,
206  int64_t incx,
207  scalar_t beta,
208  scalar_t *y,
209  int64_t incy)
210 {
211  if(n == 1)
212  lda = m;
213 
214 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
215  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) &&
216  (incx > 0) && (incx <= INT_MAX) &&
217  (incy > 0) && (incy <= INT_MAX) )
218  {
219  THArgCheck(lda >= THMax(1, m), 6,
220  "lda should be at least max(1, m=%d), but have %d", m, lda);
221  int i_m = (int)m;
222  int i_n = (int)n;
223  int i_lda = (int)lda;
224  int i_incx = (int)incx;
225  int i_incy = (int)incy;
226 
227 #if defined(TH_REAL_IS_DOUBLE)
228  dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
229 #else
230  sgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
231 #endif
232  return;
233  }
234 #endif
235  {
236  int64_t i, j;
237 
238  if( (trans == 'T') || (trans == 't') )
239  {
240  for(i = 0; i < n; i++)
241  {
242  scalar_t sum = 0;
243  scalar_t *row_ = a+lda*i;
244  for(j = 0; j < m; j++)
245  sum += x[j*incx]*row_[j];
246  if (beta == 0)
247  y[i*incy] = alpha*sum;
248  else
249  y[i*incy] = beta*y[i*incy] + alpha*sum;
250  }
251  }
252  else
253  {
254  if(beta != 1)
255  THBlas_(scal)(m, beta, y, incy);
256 
257  for(j = 0; j < n; j++)
258  {
259  scalar_t *column_ = a+lda*j;
260  scalar_t z = alpha*x[j*incx];
261  for(i = 0; i < m; i++)
262  y[i*incy] += z*column_[i];
263  }
264  }
265  }
266 }
267 
268 void THBlas_(ger)(
269  int64_t m,
270  int64_t n,
271  scalar_t alpha,
272  scalar_t *x,
273  int64_t incx,
274  scalar_t *y,
275  int64_t incy,
276  scalar_t *a,
277  int64_t lda)
278 {
279  if(n == 1)
280  lda = m;
281 
282 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
283  if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) &&
284  (incx > 0) && (incx <= INT_MAX) &&
285  (incy > 0) && (incy <= INT_MAX) )
286  {
287  THArgCheck(lda >= THMax(1, m), 9,
288  "lda should be at least max(1, m=%d), but have %d", m, lda);
289  int i_m = (int)m;
290  int i_n = (int)n;
291  int i_lda = (int)lda;
292  int i_incx = (int)incx;
293  int i_incy = (int)incy;
294 
295 #if defined(TH_REAL_IS_DOUBLE)
296  dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
297 #else
298  sger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
299 #endif
300  return;
301  }
302 #endif
303  {
304  int64_t i, j;
305  for(j = 0; j < n; j++)
306  {
307  scalar_t *column_ = a+j*lda;
308  scalar_t z = alpha*y[j*incy];
309  for(i = 0; i < m; i++)
310  column_[i] += z*x[i*incx] ;
311  }
312  }
313 }
314 
315 void THBlas_(gemm)(
316  char transa,
317  char transb,
318  int64_t m,
319  int64_t n,
320  int64_t k,
321  scalar_t alpha,
322  scalar_t *a,
323  int64_t lda,
324  scalar_t *b,
325  int64_t ldb,
326  scalar_t beta,
327  scalar_t *c,
328  int64_t ldc)
329 {
330  int transa_ = ((transa == 't') || (transa == 'T'));
331  int transb_ = ((transb == 't') || (transb == 'T'));
332 
333  if(n == 1)
334  ldc = m;
335 
336  if(transa_)
337  {
338  if(m == 1)
339  lda = k;
340  }
341  else
342  {
343  if(k == 1)
344  lda = m;
345  }
346 
347  if(transb_)
348  {
349  if(k == 1)
350  ldb = n;
351  }
352  else
353  {
354  if(n == 1)
355  ldb = k;
356  }
357 
358 #if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
359  if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) &&
360  (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
361  {
362  THArgCheck(lda >= THMax(1, (transa_ ? k : m)), 8,
363  "lda should be at least max(1, %d), but have %d", (transa_ ? k : m), lda);
364  THArgCheck(ldb >= THMax(1, (transb_ ? n : k)), 10,
365  "ldb should be at least max(1, %d), but have %d", (transb_ ? n : k), ldb);
366  THArgCheck(ldc >= THMax(1, m), 13,
367  "ldc should be at least max(1, m=%d), but have %d", m, ldc);
368  int i_m = (int)m;
369  int i_n = (int)n;
370  int i_k = (int)k;
371  int i_lda = (int)lda;
372  int i_ldb = (int)ldb;
373  int i_ldc = (int)ldc;
374 
375 #if defined(TH_REAL_IS_DOUBLE)
376  dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
377 #else
378  sgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);
379 #endif
380  return;
381  }
382 #endif
383  {
384  if(!transa_ && !transb_)
385  {
386  if (beta == 0) {
387  for (int64_t j = 0; j < n; j++) {
388  for (int64_t i = 0; i < m; i++) {
389  c[j * ldc + i] = 0;
390  }
391  }
392  }
393  else {
394  for (int64_t j = 0; j < n; j++) {
395  for (int64_t i = 0; i < m; i++) {
396  c[j * ldc + i] *= beta;
397  }
398  }
399  }
400  for (int64_t l = 0; l < k; l++) {
401  for (int64_t j = 0; j < n; j++) {
402  scalar_t val = b[l + j * ldb] * alpha;
403  int64_t i_m = m / 4;
404  for (int64_t i_i = 0; i_i < i_m; i_i++) {
405  c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
406  c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
407  c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
408  c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
409  }
410  int64_t i = i_m * 4;
411  for (; i < m; i++)
412  c[j * ldc + i] += a[i + l * lda] * val;
413  }
414  }
415  }
416  else if(transa_ && !transb_)
417  {
418  int64_t i, j, l;
419  scalar_t *a_ = a;
420  for(i = 0; i < m; i++)
421  {
422  scalar_t *b_ = b;
423  for(j = 0; j < n; j++)
424  {
425  scalar_t sum = 0;
426  for(l = 0; l < k; l++)
427  sum += a_[l]*b_[l];
428  b_ += ldb;
429  if (beta == 0)
430  c[j*ldc+i] = alpha*sum;
431  else
432  c[j*ldc+i] = beta*c[j*ldc+i]+alpha*sum;
433  }
434  a_ += lda;
435  }
436  }
437  else if(!transa_ && transb_)
438  {
439  if (beta == 0) {
440  for (int64_t j = 0; j < n; j++) {
441  for (int64_t i = 0; i < m; i++) {
442  c[j * ldc + i] = 0;
443  }
444  }
445  }
446  else {
447  for (int64_t j = 0; j < n; j++) {
448  for (int64_t i = 0; i < m; i++) {
449  c[j * ldc + i] *= beta;
450  }
451  }
452  }
453  for (int64_t l = 0; l < k; l++) {
454  for (int64_t j = 0; j < n; j++) {
455  scalar_t val = b[j + l * ldb] * alpha;
456  int64_t i_m = m / 4;
457  for (int64_t i_i = 0; i_i < i_m; i_i++) {
458  c[j * ldc + i_i * 4 + 0] += a[i_i * 4 + 0 + l * lda] * val;
459  c[j * ldc + i_i * 4 + 1] += a[i_i * 4 + 1 + l * lda] * val;
460  c[j * ldc + i_i * 4 + 2] += a[i_i * 4 + 2 + l * lda] * val;
461  c[j * ldc + i_i * 4 + 3] += a[i_i * 4 + 3 + l * lda] * val;
462  }
463  int64_t i = i_m * 4;
464  for (; i < m; i++)
465  c[j * ldc + i] += a[i + l * lda] * val;
466  }
467  }
468  }
469  else
470  {
471  for (int64_t i = 0; i < m; i++) {
472  for (int64_t j = 0; j < n; j++) {
473  if (beta == 0)
474  c[j * ldc + i] = 0;
475  else
476  c[j * ldc + i] *= beta;
477  }
478  }
479  for (int64_t i = 0; i < m; i++) {
480  for (int64_t j = 0; j < n; j++) {
481  int64_t l_k = k / 4;
482  for (int64_t l_l = 0; l_l < l_k; l_l++) {
483  c[j * ldc + i] += a[i * lda + l_l * 4 + 0] //
484  * b[(l_l * 4 + 0) * ldb + j] * alpha;
485  c[j * ldc + i] += a[i * lda + l_l * 4 + 1] //
486  * b[(l_l * 4 + 1) * ldb + j] * alpha;
487  c[j * ldc + i] += a[i * lda + l_l * 4 + 2] //
488  * b[(l_l * 4 + 2) * ldb + j] * alpha;
489  c[j * ldc + i] += a[i * lda + l_l * 4 + 3] //
490  * b[(l_l * 4 + 3) * ldb + j] * alpha;
491  }
492  int64_t l = l_k * 4;
493  for (; l < k; l++)
494  c[j * ldc + i] += a[i * lda + l] * b[l * ldb + j] * alpha;
495  }
496  }
497  }
498  }
499 }
500 
501 #endif