// Include-guard text for the generic TH template, followed by
// isTransposedContiguous: returns nonzero iff self->stride(0) == 1 and
// self->stride(1) == self->size(0), i.e. a dense column-major layout for a
// 2-D tensor (the layout LAPACK expects). Reads stride(1), so self is assumed
// to have at least 2 dimensions; this is not checked here.
1 #ifndef TH_GENERIC_FILE 2 #define TH_GENERIC_FILE "TH/generic/THTensorLapack.cpp" 8 static int THTensor_(isTransposedContiguous)(THTensor *
self)
10 return self->stride(0) == 1 &&
self->stride(1) ==
self->size(0);
// checkTransposed: if self is contiguous (row-major), transposes dims 0 and 1
// in place. For a 2-D tensor this leaves a column-major view
// (stride(0) == 1, stride(1) == size(0)). Non-contiguous tensors are left
// untouched. It is called on output tensors in this file (e.g. before
// freeCopyTo in geev) so that their layout matches the column-major buffers
// handed to LAPACK.
16 static void THTensor_(checkTransposed)(THTensor *
self)
18 if(THTensor_(isContiguous)(
self))
19 THTensor_(transpose)(
self, NULL, 0, 1);
// newTransposedContiguous: produces a tensor with a column-major layout for
// self.
// - If self is already column-major (see isTransposedContiguous), self is
//   retained, which adds a reference the caller must release.
// - Otherwise a contiguous copy is made and its dims 0 and 1 are transposed.
//   This swaps the logical shape relative to self. Callers in this file
//   (geev, gesdd2) use the result only as an output buffer for LAPACK.
// NOTE(review): the return statement is not visible in this excerpt;
// presumably it returns `tensor` (or self in the first branch).
27 static THTensor *THTensor_(newTransposedContiguous)(THTensor *
self)
30 if(THTensor_(isTransposedContiguous)(
self))
32 THTensor_(retain)(
self);
37 tensor = THTensor_(newContiguous)(
self);
38 THTensor_(transpose)(tensor, NULL, 0, 1);
// checkLapackClone: picks the tensor that will hold the column-major working
// copy of src with nrows rows (see cloneColumnMajorNrows).
//  - If result is the same tensor as src and it already has the layout LAPACK
//    needs (column-major with exactly nrows rows), it is retained and reused
//    in place.
//  - If result aliases src but has the wrong layout/shape, or result is NULL,
//    a fresh tensor is allocated.
//  - Otherwise result is retained (NOTE(review): the `else` guarding the last
//    retain is not visible in this excerpt; confirm).
// The tensor chosen here carries a reference the caller must release.
//
// The reuse test compares size(0). cloneColumnMajorNrows lays its result out
// as (nrows x src->size(1)) column-major, so the row count (and the LAPACK
// leading dimension) is dim 0. The previous size(1) comparison could reuse a
// column-major tensor with fewer than nrows rows. For example, in gels with
// m < n, rb_ == b and nrhs == n, LAPACK would then access ldb * nrhs elements
// of a smaller buffer.
53 static THTensor *THTensor_(checkLapackClone)(THTensor *result, THTensor *src,
int nrows)
56 if (src == result && THTensor_(isTransposedContiguous)(src) && src->size(0) == nrows)
57 THTensor_(retain)(result);
58 else if(src == result || result == NULL)
59 result = THTensor_(
new)();
61 THTensor_(retain)(result);
// cloneColumnMajorNrows: produces a column-major (nrows x src->size(1)) tensor
// holding a copy of src, suitable for LAPACK with leading dimension nrows.
// The target tensor is chosen by checkLapackClone, so self may be reused or
// a new tensor allocated.
// When src has fewer than nrows rows, only the first src->size(0) rows are
// copied (through a narrowed view). The remaining rows are not written here.
69 static THTensor *THTensor_(cloneColumnMajorNrows)(THTensor *
self, THTensor *src,
int nrows)
76 result = THTensor_(checkLapackClone)(
self, src, nrows);
// Allocate (src->size(1) x nrows) contiguous, then transpose it into a
// (nrows x src->size(1)) column-major view.
80 THTensor_(resize2d)(result, src->size(1), nrows);
81 THTensor_(checkTransposed)(result);
83 if (src->size(0) == nrows) {
84 at::Tensor result_wrap = THTensor_wrap(result);
86 at::_copy_same_type_(result_wrap, src_wrap);
// src is shorter than nrows: copy only into the top src->size(0) rows.
90 view = THTensor_(newNarrow)(result, 0, 0, src->size(0));
93 at::_copy_same_type_(view_wrap, src_wrap);
94 c10::raw::intrusive_ptr::decref(view);
// cloneColumnMajor: cloneColumnMajorNrows with nrows = src->size(0). It makes
// a column-major copy of src whose leading dimension is src's own row count.
// src is dereferenced here, so it must be non-NULL.
104 static THTensor *THTensor_(cloneColumnMajor)(THTensor *
self, THTensor *src)
106 return THTensor_(cloneColumnMajorNrows)(
self, src, src->size(0));
// gels: solves A x = B in the least-squares sense (m >= n) or as a
// minimum-norm problem (m < n) via LAPACK gels with TRANS = 'N'.
//   rb_, ra_  outputs. They receive the column-major working copies of b and
//             a after LAPACK has overwritten them (per LAPACK: the solution
//             in B, the factorization in A).
//   b, a      inputs. A NULL argument means the corresponding output tensor
//             is used as the input.
// Raises via THArgCheck / AT_CHECK on bad shapes or empty inputs. Raises via
// THLapackCheckWithCleanup when LAPACK reports a zero diagonal element of the
// triangular factor (A not of full rank).
109 void THTensor_(gels)(THTensor *rb_, THTensor *ra_, THTensor *b, THTensor *a)
113 if (a == NULL) a = ra_;
114 if (b == NULL) b = rb_;
115 THArgCheck(a->dim() == 2, 2,
"A should have 2 dimensions, but has %d",
117 THArgCheck(!a->is_empty(), 2,
"A should not be empty");
118 THArgCheck(b->dim() == 1 || b->dim() == 2, 1,
"B should have 1 or 2 " 119 "dimensions, but has %d", b->dim());
120 THArgCheck(!b->is_empty(), 1,
"B should not be empty");
121 AT_CHECK(a->size(0) == b->size(0),
"Expected A and b to have same size " 122 "at dim 0, but A has ", a->size(0),
" rows and B has ", b->size(0),
" rows");
// A 1-D right-hand side is re-viewed as a 2-D tensor over the same storage.
// This new TH object is released at the end (and on error) when free_b is set.
124 if (THTensor_nDimensionLegacyAll(b) == 1) {
125 b = THTensor_(newWithStorage2d)(THTensor_getStoragePtr(b), b->storage_offset(), b->size(0),
130 int m, n, nrhs, lda, ldb, info, lwork;
131 THTensor *work = NULL;
134 THTensor *ra__ = NULL;
135 THTensor *rb__ = NULL;
137 ra__ = THTensor_(cloneColumnMajor)(ra_, a);
// LAPACK requires LDB >= max(M, N): the B buffer must hold both B (m rows)
// and the solution X (n rows).
142 ldb = (m > n) ? m : n;
144 rb__ = THTensor_(cloneColumnMajorNrows)(rb_, b, ldb);
146 nrhs = rb__->size(1);
// First call: workspace-size query (LAPACK convention). Second call: solve.
151 THLapack_(gels)(
'N', m, n, nrhs, ra__->data<scalar_t>(), lda,
152 rb__->data<scalar_t>(), ldb,
155 work = THTensor_(newWithSize1d)(lwork);
156 THLapack_(gels)(
'N', m, n, nrhs, ra__->data<scalar_t>(), lda,
157 rb__->data<scalar_t>(), ldb,
158 work->data<scalar_t>(), lwork, &info);
160 THLapackCheckWithCleanup(
"Lapack Error in %s : The %d-th diagonal element of the triangular factor of A is zero",
161 THCleanup(c10::raw::intrusive_ptr::decref(ra__);
162 c10::raw::intrusive_ptr::decref(rb__);
163 c10::raw::intrusive_ptr::decref(work);
164 if (free_b) c10::raw::intrusive_ptr::decref(b);),
// For an underdetermined system solved in place, rb_ must grow to n rows to
// receive the solution.
172 if (m < n && b == rb_) {
173 THTensor_(resize2d)(rb_, n, nrhs);
176 THTensor_(freeCopyTo)(ra__, ra_);
177 THTensor_(freeCopyTo)(rb__, rb_);
178 c10::raw::intrusive_ptr::decref(work);
179 if (free_b) c10::raw::intrusive_ptr::decref(b);
// geev: eigenvalues (and optionally right eigenvectors) of a general square
// matrix a_, via LAPACK geev with JOBVL = 'N' (left eigenvectors are never
// computed).
//   re_    resized to (n x 2). Row i holds the real and imaginary parts of
//          the i-th eigenvalue (copied from wr / wi below).
//   rv_    receives the right eigenvectors when jobvr[0] requests them.
//          NOTE(review): the condition guarding the rv_ setup is not visible
//          in this excerpt.
//   a_     input matrix. It is copied into a fresh column-major tensor and
//          never modified.
//   jobvr  LAPACK JOBVR flag.
// Raises via THArgCheck if a_ is not 2-D or not square. Raises via
// THLapackCheckWithCleanup if the QR algorithm fails to converge.
182 void THTensor_(geev)(THTensor *re_, THTensor *rv_, THTensor *a_,
const char *jobvr)
184 int n, lda, lwork, info, ldvr;
185 THTensor *work=
nullptr, *wi, *wr, *a;
190 THTensor *re__ = NULL;
191 THTensor *rv__ = NULL;
193 THArgCheck(a_->dim() == 2, 1,
"A should be 2 dimensional");
194 THArgCheck(a_->size(0) == a_->size(1), 1,
"A should be square");
// Passing NULL as the target forces checkLapackClone to allocate a new copy.
197 a = THTensor_(cloneColumnMajor)(NULL, a_);
202 wi = THTensor_(newWithSize1d)(n);
203 wr = THTensor_(newWithSize1d)(n);
209 THTensor_(resize2d)(rv_,n,n);
211 rv__ = THTensor_(newTransposedContiguous)(rv_);
212 rv_data = rv__->data<scalar_t>();
215 THTensor_(resize2d)(re_,n,2);
216 re__ = THTensor_(newContiguous)(re_);
// First call: workspace-size query. Second call: the decomposition itself.
220 THLapack_(geev)(
'N', jobvr[0], n, a->data<scalar_t>(), lda, wr->data<scalar_t>(), wi->data<scalar_t>(),
221 NULL, 1, rv_data, ldvr, &wkopt, -1, &info);
224 work = THTensor_(newWithSize1d)(lwork);
226 THLapack_(geev)(
'N', jobvr[0], n, a->data<scalar_t>(), lda, wr->data<scalar_t>(), wi->data<scalar_t>(),
227 NULL, 1, rv_data, ldvr, work->data<scalar_t>(), lwork, &info);
229 THLapackCheckWithCleanup(
" Lapack Error in %s : %d off-diagonal elements of an didn't converge to zero",
230 THCleanup(c10::raw::intrusive_ptr::decref(re__);
231 c10::raw::intrusive_ptr::decref(rv__);
232 c10::raw::intrusive_ptr::decref(a);
233 c10::raw::intrusive_ptr::decref(wi);
234 c10::raw::intrusive_ptr::decref(wr);
235 c10::raw::intrusive_ptr::decref(work);),
// Interleave (wr[i], wi[i]) into row i of the contiguous (n x 2) re__.
240 scalar_t *re_data = re__->data<scalar_t>();
241 scalar_t *wi_data = wi->data<scalar_t>();
242 scalar_t *wr_data = wr->data<scalar_t>();
245 re_data[2*i] = wr_data[i];
246 re_data[2*i+1] = wi_data[i];
// Give rv_ the same column-major layout as rv__ before copying back.
252 THTensor_(checkTransposed)(rv_);
253 THTensor_(freeCopyTo)(rv__, rv_);
255 THTensor_(freeCopyTo)(re__, re_);
256 c10::raw::intrusive_ptr::decref(a);
257 c10::raw::intrusive_ptr::decref(wi);
258 c10::raw::intrusive_ptr::decref(wr);
259 c10::raw::intrusive_ptr::decref(work);
// syev: eigen-decomposition of a symmetric matrix a via LAPACK syev.
//   re_   resized to 1-D of length n; receives the eigenvalues.
//   rv_   receives the column-major working copy of a, which LAPACK
//         overwrites (with the eigenvectors when jobz requests them). When a
//         is NULL, rv_ itself is the input.
//   jobz  LAPACK JOBZ flag ('N' = eigenvalues only, 'V' = also eigenvectors).
//         rv_ is zero-filled in a branch whose condition is not visible in
//         this excerpt (presumably jobz == 'N').
//   uplo  LAPACK UPLO flag selecting which triangle of a is referenced.
// Raises via THArgCheck if a is not 2-D or not square. Raises via
// THLapackCheckWithCleanup if LAPACK reports non-convergence.
//
// LAPACK must write the eigenvalues into re__, the contiguous buffer from
// newContiguous(re_), not into re_ directly. When re_ is non-contiguous, re__
// is a separate copy. Writing through re_->data would then fill re_'s storage
// with unit stride (the wrong elements), and freeCopyTo(re__, re_) would
// overwrite re_ with the stale copy, losing the eigenvalues. When re_ is
// contiguous, re__ == re_ and nothing changes.
262 void THTensor_(syev)(THTensor *re_, THTensor *rv_, THTensor *a,
const char *jobz,
const char *uplo)
264 if (a == NULL) a = rv_;
265 THArgCheck(a->dim() == 2, 1,
"A should be 2 dimensional");
266 THArgCheck(a->size(0) == a->size(1), 1,
"A should be square");
268 int n, lda, lwork, info;
269 THTensor *work =
nullptr;
272 THTensor *rv__ = NULL;
273 THTensor *re__ = NULL;
275 rv__ = THTensor_(cloneColumnMajor)(rv_, a);
277 n = THTensor_sizeLegacyNoScalars(rv__, 0);
280 THTensor_(resize1d)(re_,n);
281 re__ = THTensor_(newContiguous)(re_);
// First call: workspace-size query. Second call: the decomposition itself.
285 THLapack_(syev)(jobz[0], uplo[0], n, rv__->data<scalar_t>(), lda,
286 re__->data<scalar_t>(), &wkopt, -1, &info);
288 work = THTensor_(newWithSize1d)(lwork);
289 THLapack_(syev)(jobz[0], uplo[0], n, rv__->data<scalar_t>(), lda,
290 re__->data<scalar_t>(), work->data<scalar_t>(), lwork, &info);
292 THLapackCheckWithCleanup(
"Lapack Error %s : %d off-diagonal elements didn't converge to zero",
293 THCleanup(c10::raw::intrusive_ptr::decref(rv__);
294 c10::raw::intrusive_ptr::decref(re__);
295 c10::raw::intrusive_ptr::decref(work);),
301 THTensor_(fill)(rv_, 0);
304 THTensor_(freeCopyTo)(rv__, rv_);
305 THTensor_(freeCopyTo)(re__, re_);
306 c10::raw::intrusive_ptr::decref(work);
// gesdd: convenience wrapper around gesdd2. It supplies a scratch tensor to
// hold the column-major working copy of a (which LAPACK overwrites) and
// releases it afterwards, so the caller's a is not modified.
309 void THTensor_(gesdd)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *a,
const char* some,
const char* compute_uv)
311 THTensor *ra_ = THTensor_(
new)();
312 THTensor_(gesdd2)(ru_, rs_, rv_, ra_, a, some, compute_uv);
313 c10::raw::intrusive_ptr::decref(ra_);
// gesdd2: singular value decomposition of a via LAPACK gesdd.
//   ru_         left singular vectors.
//   rs_         singular values, resized to length k. NOTE(review): k is
//               assigned in lines not visible here; presumably min(m, n).
//   rv_         right singular vectors. They are computed into the separate
//               tensor rvf_ and copied into rv_ at the end.
//   ra_         receives the column-major working copy of a (LAPACK
//               overwrites it). When a is NULL, ra_ is also the input.
//   some        NOTE(review): presumably selects reduced vs full factors when
//               building the LAPACK jobz; that code is not visible here.
//   compute_uv  when 'N', U and V are not requested from LAPACK.
// Raises via THArgCheck if a is not 2-D or is empty. Raises via
// THLapackCheckWithCleanup if LAPACK reports non-convergence.
316 void THTensor_(gesdd2)(THTensor *ru_, THTensor *rs_, THTensor *rv_, THTensor *ra_, THTensor *a,
317 const char* some,
const char* compute_uv)
319 if (a == NULL) a = ra_;
320 THArgCheck(a->dim() == 2, 1,
"A should be 2 dimensional");
321 THArgCheck(!a->is_empty(), 1,
"A should not be empty");
323 int k, m, n, lda, ldu, ldvt, lwork, info;
328 THTensor *ra__ = NULL;
329 THTensor *ru__ = NULL;
330 THTensor *rs__ = NULL;
331 THTensor *rv__ = NULL;
333 ra__ = THTensor_(cloneColumnMajor)(ra_, a);
// NOTE(review): LAPACK gesdd needs IWORK of length 8*min(M,N). Here k is only
// used as a truthiness test, so the size is 8*m whenever k != 0. That is at
// least 8*min(m, n) and therefore sufficient, but 8*k looks like the intent.
343 iwork = k ? THIntTensor_newWithSize1d((int64_t)(8 * m)) : THIntTensor_newWithSize1d((int64_t)(8 * n));
345 THTensor_(resize1d)(rs_,k);
346 THTensor *rvf_ = NULL;
348 if (*compute_uv !=
'N') {
349 rvf_ = THTensor_(
new)();
350 THTensor_(resize2d)(rvf_,ldvt,n);
352 THTensor_(resize2d)(ru_,m,ldu);
354 THTensor_(resize2d)(ru_,k,ldu);
356 THTensor_(resize2d)(rv_,ldvt,n);
357 THTensor_(resize2d)(ru_,m,ldu);
// Make ru_ column-major so that the final freeCopyTo from ru__ matches its
// layout.
360 THTensor_(checkTransposed)(ru_);
363 scalar_t *rs__data = NULL;
364 scalar_t *ru__data = NULL;
365 scalar_t *rv__data = NULL;
367 rs__ = THTensor_(newContiguous)(rs_);
368 rs__data = rs__->data<scalar_t>();
369 if (*compute_uv !=
'N') {
371 ru__ = THTensor_(newTransposedContiguous)(ru_);
372 rv__ = THTensor_(newContiguous)(rvf_);
374 ru__data = ru__->data<scalar_t>();
375 rv__data = rv__->data<scalar_t>();
// First call: workspace-size query. Second call: the decomposition itself.
380 THLapack_(gesdd)(jobz,
381 m,n,ra__->data<scalar_t>(),lda,
386 &wkopt, -1, THIntTensor_data(iwork), &info);
388 work = THTensor_(newWithSize1d)(lwork);
389 THLapack_(gesdd)(jobz,
390 m,n,ra__->data<scalar_t>(),lda,
395 work->data<scalar_t>(),lwork, THIntTensor_data(iwork), &info);
// NOTE(review): neither cleanup list below releases rvf_'s own reference,
// so rvf_ leaks when LAPACK reports an error.
398 THLapackCheckWithCleanup(
"Lapack Error %s : %d superdiagonals failed to converge.",
400 c10::raw::intrusive_ptr::decref(ru__);
401 c10::raw::intrusive_ptr::decref(rs__);
402 c10::raw::intrusive_ptr::decref(rv__);
403 c10::raw::intrusive_ptr::decref(ra__);
404 c10::raw::intrusive_ptr::decref(work);
405 c10::raw::intrusive_ptr::decref(iwork);),
408 THLapackCheckWithCleanup(
"Lapack Error %s : %d superdiagonals failed to converge.",
410 c10::raw::intrusive_ptr::decref(rs__);
411 c10::raw::intrusive_ptr::decref(ra__);
412 c10::raw::intrusive_ptr::decref(work);
413 c10::raw::intrusive_ptr::decref(iwork);),
417 THTensor_(freeCopyTo)(ra__, ra_);
418 THTensor_(freeCopyTo)(rs__, rs_);
419 c10::raw::intrusive_ptr::decref(work);
420 c10::raw::intrusive_ptr::decref(iwork);
// Keep the first k columns of the V buffer, then copy it into rv_.
424 THTensor_(narrow)(rv__,NULL,1,0,k);
426 THTensor_(freeCopyTo)(ru__, ru_);
427 THTensor_(freeCopyTo)(rv__, rvf_);
430 THTensor_(narrow)(rvf_,NULL,1,0,k);
432 THTensor_(resizeAs)(rv_, rvf_);
435 at::_copy_same_type_(rv__wrap, rvf__wrap);
436 c10::raw::intrusive_ptr::decref(rvf_);
// When U/V were not computed (presumably compute_uv == 'N'), the outputs are
// zero-filled.
438 THTensor_(zero)(ru_);
439 THTensor_(zero)(rv_);
// getri: inverse of a square matrix. It first factors a with LAPACK getrf
// (LU with partial pivoting), then inverts with LAPACK getri.
//   ra_  receives the inverse (the column-major working copy is copied back).
//   a    input. When NULL, ra_ is inverted in place.
// Raises via THArgCheck if a is not 2-D or not square. Raises via
// THLapackCheckWithCleanup if U(info, info) is exactly zero (singular matrix).
443 void THTensor_(getri)(THTensor *ra_, THTensor *a)
445 if (a == NULL) a = ra_;
446 THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1,
"A should be 2 dimensional");
447 THArgCheck(a->size(0) == a->size(1), 1,
"A should be square");
449 int m, n, lda, info, lwork;
453 THTensor *ra__ = NULL;
455 ra__ = THTensor_(cloneColumnMajor)(ra_, a);
// Pivot indices produced by getrf and consumed by getri.
460 ipiv = THIntTensor_newWithSize1d((int64_t)m);
463 THLapack_(getrf)(n, n, ra__->data<scalar_t>(), lda, THIntTensor_data(ipiv), &info);
464 THLapackCheckWithCleanup(
"Lapack Error %s : U(%d,%d) is 0, U is singular",
466 c10::raw::intrusive_ptr::decref(ra__);
467 THIntTensor_free(ipiv);),
468 "getrf", info, info);
// Workspace-size query, then the actual inversion.
471 THLapack_(getri)(n, ra__->data<scalar_t>(), lda, THIntTensor_data(ipiv), &wkopt, -1, &info);
473 work = THTensor_(newWithSize1d)(lwork);
474 THLapack_(getri)(n, ra__->data<scalar_t>(), lda, THIntTensor_data(ipiv), work->data<scalar_t>(), lwork, &info);
475 THLapackCheckWithCleanup(
"Lapack Error %s : U(%d,%d) is 0, U is singular",
477 c10::raw::intrusive_ptr::decref(ra__);
478 c10::raw::intrusive_ptr::decref(work);
479 THIntTensor_free(ipiv);),
480 "getri", info, info);
482 THTensor_(freeCopyTo)(ra__, ra_);
483 c10::raw::intrusive_ptr::decref(work);
484 THIntTensor_free(ipiv);
// clearUpLoTriangle: operates in place on the raw data pointer of a square
// 2-D tensor a. The first branch visits the index pairs (i, j > i); the
// uplo[0] == 'L' branch visits (i, j < i).
// NOTE(review): the loop bodies and the first branch's condition are not
// visible in this excerpt. By name, this clears the triangle opposite uplo;
// confirm. The flat indexing assumes a's data is a dense n*n block, which is
// not checked here.
487 void THTensor_(clearUpLoTriangle)(THTensor *a,
const char *uplo)
489 THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1,
"A should be 2 dimensional");
490 THArgCheck(a->size(0) == a->size(1), 1,
"A should be square");
495 scalar_t *p = a->data<scalar_t>();
502 for (i=0; i<n; i++) {
503 for (j=i+1; j<n; j++) {
509 else if (uplo[0] ==
'L')
512 for (i=0; i<n; i++) {
513 for (j=0; j<i; j++) {
// copyUpLoTriangle: makes the square 2-D tensor a symmetric in place by
// mirroring one triangle of its raw n*n data block onto the other.
// In the first branch (j > i), each p[n*i + j] is overwritten with its mirror
// p[n*j + i]. The uplo[0] == 'L' branch does the same for j < i.
// Which matrix triangle serves as the source depends on a's memory layout.
// potri calls this on a column-major working copy.
// NOTE(review): the first branch's condition is not visible in this excerpt.
// The flat indexing assumes a dense n*n block, which is not checked here.
520 void THTensor_(copyUpLoTriangle)(THTensor *a,
const char *uplo)
522 THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1,
"A should be 2 dimensional");
523 THArgCheck(a->size(0) == a->size(1), 1,
"A should be square");
528 scalar_t *p = a->data<scalar_t>();
535 for (i=0; i<n; i++) {
536 for (j=i+1; j<n; j++) {
537 p[n*i + j] = p[n*j+i];
542 else if (uplo[0] ==
'L')
545 for (i=0; i<n; i++) {
546 for (j=0; j<i; j++) {
547 p[n*i + j] = p[n*j+i];
// potri: inverse of a symmetric positive-definite matrix from its Cholesky
// factor (as produced by potrf), via LAPACK potri.
//   ra_   receives the inverse (the column-major working copy is copied back).
//   a     the Cholesky factor. When NULL, ra_ is used.
//   uplo  LAPACK UPLO flag: which triangle holds the factor.
// LAPACK potri fills only the uplo triangle of the inverse, so
// copyUpLoTriangle mirrors it to produce a fully symmetric result.
// Raises via THArgCheck if a is not 2-D or not square. Raises via
// THLapackCheckWithCleanup on a LAPACK error.
553 void THTensor_(potri)(THTensor *ra_, THTensor *a,
const char *uplo)
555 if (a == NULL) a = ra_;
556 THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1,
"A should be 2 dimensional");
557 THArgCheck(a->size(0) == a->size(1), 1,
"A should be square");
560 THTensor *ra__ = NULL;
562 ra__ = THTensor_(cloneColumnMajor)(ra_, a);
564 n = THTensor_sizeLegacyNoScalars(ra__, 0);
568 THLapack_(potri)(uplo[0], n, ra__->data<scalar_t>(), lda, &info);
569 THLapackCheckWithCleanup(
"Lapack Error %s : A(%d,%d) is 0, A cannot be factorized",
570 THCleanup(c10::raw::intrusive_ptr::decref(ra__);),
571 "potri", info, info);
573 THTensor_(copyUpLoTriangle)(ra__, uplo);
574 THTensor_(freeCopyTo)(ra__, ra_);
// pstrf: Cholesky factorization with complete pivoting of a symmetric
// positive semidefinite matrix, via LAPACK pstrf.
//   ra_    receives the factor. Before being copied back, ra__ is passed
//          through clearUpLoTriangle(ra__, uplo).
//   rpiv_  resized to n; receives LAPACK's pivot indices.
//   uplo   LAPACK UPLO flag.
//   tol    LAPACK TOL. Per LAPACK, the factorization stops once a pivot falls
//          to or below tol.
// Unlike most routines here, there is no NULL fallback, so a must be non-NULL.
// Raises via THArgCheck if a is not 2-D or not square. Raises via
// THLapackCheckWithCleanup when LAPACK reports rank deficiency or a failure.
593 void THTensor_(pstrf)(THTensor *ra_, THIntTensor *rpiv_, THTensor *a,
const char *uplo, scalar_t tol) {
594 THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1,
"A should be 2 dimensional");
595 THArgCheck(a->size(0) == a->size(1), 1,
"A should be square");
599 THTensor *ra__ = THTensor_(cloneColumnMajor)(ra_, a);
600 THIntTensor_resize1d(rpiv_, n);
// LAPACK pstrf requires a workspace of 2*n elements.
603 THTensor *work = THTensor_(newWithSize1d)(2 * n);
609 THLapack_(pstrf)(uplo[0], n, ra__->data<scalar_t>(), lda,
610 THIntTensor_data(rpiv_), &rank, tol,
611 work->data<scalar_t>(), &info);
613 THLapackCheckWithCleanup(
"Lapack Error %s : matrix is rank deficient or not positive semidefinite",
615 c10::raw::intrusive_ptr::decref(ra__);
616 c10::raw::intrusive_ptr::decref(work);),
619 THTensor_(clearUpLoTriangle)(ra__, uplo);
621 THTensor_(freeCopyTo)(ra__, ra_);
622 c10::raw::intrusive_ptr::decref(work);
// qr: QR decomposition of a, built from geqrf followed by orgqr.
//   rq_  receives Q, narrowed to its first k = min(m, n) columns.
//   rr_  receives R: the upper triangle (triu) of the first k rows of the
//        geqrf output.
// NOTE(review): m and n are assigned in lines not visible here (presumably
// a's sizes).
640 void THTensor_(qr)(THTensor *rq_, THTensor *rr_, THTensor *a)
644 int k = (m < n ? m : n);
645 THTensor *ra_ = THTensor_(
new)();
646 THTensor *rtau_ = THTensor_(
new)();
647 THTensor *rr__ = THTensor_(
new)();
648 THTensor_(geqrf)(ra_, rtau_, a);
// NOTE(review): this resize appears to be superseded by the narrow below,
// which re-points rr__ at a view of ra_.
649 THTensor_(resize2d)(rr__, k, ra_->size(1));
650 THTensor_(narrow)(rr__, ra_, 0, 0, k);
651 THTensor_(triu)(rr_, rr__, 0);
652 THTensor_(resize2d)(rq_, ra_->size(0), k);
653 THTensor_(orgqr)(rq_, ra_, rtau_);
654 THTensor_(narrow)(rq_, rq_, 1, 0, k);
655 c10::raw::intrusive_ptr::decref(ra_);
656 c10::raw::intrusive_ptr::decref(rtau_);
657 c10::raw::intrusive_ptr::decref(rr__);
// geqrf: QR factorization of an (m x n) matrix a via LAPACK geqrf.
//   ra_    receives the column-major factored matrix. Per LAPACK geqrf, R is
//          in the upper triangle and the Householder vectors are below it.
//   rtau_  resized to k = min(m, n); receives the Householder scalar factors.
//          NOTE(review): its data pointer is passed straight to LAPACK, which
//          assumes rtau_ is contiguous; this is not checked here.
//   a      input matrix. When NULL, ra_ is used as the input (in place).
// Raises via THArgCheck if a is not 2-D or is empty. Raises via
// THLapackCheckWithCleanup on a LAPACK error.
//
// The NULL fallback previously read `ra_ = a`. That discarded the output
// tensor and then dereferenced the NULL a in a->dim(). It now matches every
// other routine in this file: `a = ra_`.
677 void THTensor_(geqrf)(THTensor *ra_, THTensor *rtau_, THTensor *a)
679 if (a == NULL) a = ra_;
680 THArgCheck(a->dim() == 2, 1,
"A should be 2 dimensional");
681 THArgCheck(!a->is_empty(), 1,
"A should not be empty");
683 THTensor *ra__ = NULL;
686 ra__ = THTensor_(cloneColumnMajor)(ra_, a);
688 int m = ra__->size(0);
689 int n = ra__->size(1);
690 int k = (m < n ? m : n);
692 THTensor_(resize1d)(rtau_, k);
// First call: workspace-size query. Second call: the factorization itself.
697 THLapack_(geqrf)(m, n, ra__->data<scalar_t>(), lda,
698 rtau_->data<scalar_t>(),
702 int lwork = (int)wkopt;
703 THTensor *work = THTensor_(newWithSize1d)(lwork);
704 THLapack_(geqrf)(m, n, ra__->data<scalar_t>(), lda,
705 rtau_->data<scalar_t>(),
706 work->data<scalar_t>(), lwork, &info);
708 THLapackCheckWithCleanup(
"Lapack Error %s : unknown Lapack error. info = %i",
710 c10::raw::intrusive_ptr::decref(ra__);
711 c10::raw::intrusive_ptr::decref(work);),
714 THTensor_(freeCopyTo)(ra__, ra_);
715 c10::raw::intrusive_ptr::decref(work);
// orgqr: forms the orthogonal matrix Q from the Householder reflectors in a
// and the scalar factors tau (as produced by geqrf), via LAPACK orgqr.
// LAPACK is called with N = K = tau's length, so only the first k columns of
// Q are generated.
//   ra_  receives Q (the column-major working copy is copied back).
//   a    the reflectors. When NULL, ra_ is used.
//   tau  the scalar factors. NOTE(review): its data pointer is passed straight
//        to LAPACK, which assumes tau is contiguous; this is not checked here.
// Raises via THArgCheck if a is not 2-D. Raises via THLapackCheckWithCleanup
// on a LAPACK error.
734 void THTensor_(orgqr)(THTensor *ra_, THTensor *a, THTensor *tau)
736 if (a == NULL) a = ra_;
737 THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1,
"A should be 2 dimensional");
739 THTensor *ra__ = NULL;
740 ra__ = THTensor_(cloneColumnMajor)(ra_, a);
742 int m = THTensor_sizeLegacyNoScalars(ra__, 0);
743 int k = THTensor_sizeLegacyNoScalars(tau, 0);
// First call: workspace-size query. Second call: forms Q.
749 THLapack_(orgqr)(m, k, k, ra__->data<scalar_t>(), lda,
750 tau->data<scalar_t>(),
754 int lwork = (int)wkopt;
755 THTensor *work = THTensor_(newWithSize1d)(lwork);
756 THLapack_(orgqr)(m, k, k, ra__->data<scalar_t>(), lda,
757 tau->data<scalar_t>(),
758 work->data<scalar_t>(), lwork, &info);
760 THLapackCheckWithCleanup(
" Lapack Error %s : unknown Lapack error. info = %i",
762 c10::raw::intrusive_ptr::decref(ra__);
763 c10::raw::intrusive_ptr::decref(work);),
765 THTensor_(freeCopyTo)(ra__, ra_);
766 c10::raw::intrusive_ptr::decref(work);
// ormqr: multiplies c by Q or Q^T, from the left or the right, via LAPACK
// ormqr. Q is given by the Householder reflectors in a and the scalar
// factors tau.
//   ra_    receives the product. It starts as a column-major copy of c.
//   a      the reflectors. When NULL, ra_ is used. NOTE(review): a's data
//          pointer is passed to LAPACK directly, assuming a is already
//          column-major with leading dimension lda (computed in lines not
//          visible here); this is not checked.
//   tau    the scalar factors. NOTE(review): assumed contiguous; not checked.
//   side   LAPACK SIDE flag.
//   trans  LAPACK TRANS flag.
// Raises via THArgCheck if a is not 2-D. Raises via THLapackCheckWithCleanup
// on a LAPACK error.
787 void THTensor_(ormqr)(THTensor *ra_, THTensor *a, THTensor *tau, THTensor *c,
const char *side,
const char *trans)
789 if (a == NULL) a = ra_;
790 THArgCheck(THTensor_nDimensionLegacyAll(a) == 2, 1,
"A should be 2 dimensional");
792 THTensor *ra__ = NULL;
// The working buffer LAPACK overwrites is a copy of c, not of a.
793 ra__ = THTensor_(cloneColumnMajor)(ra_, c);
797 int k = THTensor_sizeLegacyNoScalars(tau, 0);
// First call: workspace-size query. Second call: applies Q.
812 THLapack_(ormqr)(side[0], trans[0], m, n, k, a->data<scalar_t>(), lda,
813 tau->data<scalar_t>(), ra__->data<scalar_t>(), ldc,
817 int lwork = (int)wkopt;
818 THTensor *work = THTensor_(newWithSize1d)(lwork);
819 THLapack_(ormqr)(side[0], trans[0], m, n, k, a->data<scalar_t>(), lda,
820 tau->data<scalar_t>(), ra__->data<scalar_t>(), ldc,
821 work->data<scalar_t>(), lwork, &info);
823 THLapackCheckWithCleanup(
" Lapack Error %s : unknown Lapack error. info = %i",
825 c10::raw::intrusive_ptr::decref(ra__);
826 c10::raw::intrusive_ptr::decref(work);),
828 THTensor_(freeCopyTo)(ra__, ra_);
829 c10::raw::intrusive_ptr::decref(work);
// btrisolve: solves a batch of linear systems A_i x = b_i from precomputed LU
// factors, via LAPACK getrs with TRANS = 'N'.
//   rb_     output. It is resized like b, initialized from b, and then
//           overwritten with the solutions.
//   b       right-hand sides: 2-D (batch x n) or 3-D (batch x n x nrhs).
//   atf     LU factors, 3-D (batch x n x n).
//   pivots  per-batch pivot indices; must be contiguous.
// Only float and double are supported; other types raise "Unimplemented".
// Raises via AT_CHECK / THArgCheck on shape mismatches, and via THError on a
// non-contiguous pivots tensor or a nonzero LAPACK info.
// NOTE(review): the THError paths do not release ai, rbi, pivoti, atf_ or rb__.
832 void THTensor_(btrisolve)(THTensor *rb_, THTensor *b, THTensor *atf, THIntTensor *pivots)
834 AT_CHECK(!atf->is_empty() && THTensor_(nDimensionLegacyNoScalars)(atf) == 3,
"expected non-empty 3D tensor, got size: ",
836 AT_CHECK(!b->is_empty() && (THTensor_(nDimensionLegacyNoScalars)(b) == 3 ||
837 THTensor_(nDimensionLegacyNoScalars)(b) == 2),
"expected non-empty 2D or 3D tensor, got size: ", b->sizes());
838 THArgCheck(THTensor_(size)(atf, 0) ==
839 THTensor_(size)(b, 0), 3,
"number of batches must be equal");
840 THArgCheck(THTensor_(size)(atf, 1) ==
841 THTensor_(size)(atf, 2), 3,
"A matrices must be square");
842 THArgCheck(THTensor_(size)(atf, 1) ==
843 THTensor_(size)(b, 1), 3,
"dimensions of A and b must be equal");
846 THTensor_(resizeAs)(rb_, b);
849 at::_copy_same_type_(rb__wrap, b_wrap);
852 int64_t num_batches = atf->size(0);
853 int64_t n = atf->size(1);
854 int nrhs = THTensor_nDimensionLegacyAll(rb_) > 2 ? rb_->size(2) : 1;
// getrs needs each matrix column-major. When atf->stride(1) == 1, atf is
// (presumably) used directly. Otherwise a column-major clone is made by
// transposing, cloning and transposing back.
861 if (atf->stride(1) == 1) {
863 lda = atf->stride(2);
870 THTensor *transp_r_ = THTensor_(newTranspose)(atf, 1, 2);
871 atf_ = THTensor_(newClone)(transp_r_);
872 c10::raw::intrusive_ptr::decref(transp_r_);
873 THTensor_(transpose)(atf_, NULL, 1, 2);
874 lda = atf_->stride(2);
// The same column-major requirement applies to the right-hand sides.
878 if (rb_->stride(1) == 1) {
880 if (THTensor_nDimensionLegacyAll(rb_) == 2 || rb_->size(2) == 1) {
883 ldb = rb_->stride(2);
888 if (THTensor_nDimensionLegacyAll(rb_) > 2) {
889 THTensor *transp_r_ = THTensor_(newTranspose)(rb_, 1, 2);
890 rb__ = THTensor_(newClone)(transp_r_);
891 c10::raw::intrusive_ptr::decref(transp_r_);
892 THTensor_(transpose)(rb__, NULL, 1, 2);
893 ldb = rb__->stride(2);
895 rb__ = THTensor_(newClone)(rb_);
// Reusable per-batch views, re-pointed by select on each iteration.
900 THTensor *ai = THTensor_(
new)();
901 THTensor *rbi = THTensor_(
new)();
902 THIntTensor *pivoti = THIntTensor_new();
904 if (!THIntTensor_isContiguous(pivots)) {
905 THError(
"Error: rpivots_ is not contiguous.");
908 for (int64_t batch = 0; batch < num_batches; ++batch) {
909 THTensor_(select)(ai, atf_, 0, batch);
910 THTensor_(select)(rbi, rb__, 0, batch);
911 THIntTensor_select(pivoti, pivots, 0, batch);
913 #if defined(TH_REAL_IS_FLOAT) || defined(TH_REAL_IS_DOUBLE) 915 THLapack_(getrs)(
'N', n, nrhs, ai->data<scalar_t>(), lda,
916 THIntTensor_data(pivoti), rbi->data<scalar_t>(),
919 THError(
"Error: Nonzero info.");
922 THError(
"Unimplemented");
926 c10::raw::intrusive_ptr::decref(ai);
927 c10::raw::intrusive_ptr::decref(rbi);
928 THIntTensor_free(pivoti);
931 c10::raw::intrusive_ptr::decref(atf_);
935 THTensor_(freeCopyTo)(rb__, rb_);