import warnings

import torch
import torch.nn.functional as F
from torch._C import _add_docstr  # used by isnan below

from math import inf  # used by isfinite / isinf below
from operator import mul
from functools import reduce
from collections import Iterable
from itertools import product

def broadcast_tensors(*tensors):
    r"""broadcast_tensors(*tensors) -> List of Tensors

    Broadcasts the given tensors according to :ref:`broadcasting-semantics`.

    Args:
        *tensors: any number of tensors of the same type

    .. warning::

        More than one element of a broadcasted tensor may refer to a single
        memory location. As a result, in-place operations (especially ones that
        are vectorized) may result in incorrect behavior. If you need to write
        to the tensors, please clone them first.

    Example::

        >>> x = torch.arange(3).view(1, 3)
        >>> y = torch.arange(2).view(2, 1)
        >>> a, b = torch.broadcast_tensors(x, y)
        >>> a
        tensor([[0, 1, 2],
                [0, 1, 2]])
    """
    return torch._C._VariableFunctions.broadcast_tensors(tensors)

def split(tensor, split_size_or_sections, dim=0):
    r"""Splits the tensor into chunks.

    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
    be split into equally sized chunks (if possible). Last chunk will be smaller if
    the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size_or_sections`.

    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
    to :attr:`split_size_or_sections`.

    Arguments:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int) or (list(int)): size of a single chunk or
            list of sizes for each chunk
        dim (int): dimension along which to split the tensor.
    """
    return tensor.split(split_size_or_sections, dim)

def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
    r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.

    Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.

    Arguments:
        LU_data (Tensor): the packed LU factorization data
        LU_pivots (Tensor): the packed LU factorization pivots
        unpack_data (bool): flag indicating if the data should be unpacked
        unpack_pivots (bool): flag indicating if the pivots should be unpacked

    Example::

        >>> A = torch.randn(2, 3, 3)
        >>> A_LU, pivots = A.btrifact()
        >>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
        >>>
        >>> # can recover A from factorization
        >>> A_ = torch.bmm(P, torch.bmm(A_L, A_U))
    """
    sz = LU_data.size(-1)

    if unpack_data:
        # U is the upper triangle of the packed data; L is the strict lower
        # triangle with an implicit unit diagonal.
        U = LU_data.triu()
        L = LU_data.tril()
        L.diagonal(dim1=-2, dim2=-1).fill_(1)
    else:
        L = U = None

    if unpack_pivots:
        P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
        # pivots follow the LAPACK convention and are 1-based; convert to 0-based
        LU_pivots = LU_pivots - 1
        for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
            final_order = list(range(sz))
            for k, j in enumerate(LU_pivots[idx]):
                final_order[k], final_order[j] = final_order[j], final_order[k]
            P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
    else:
        P = None

    return P, L, U
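
# An illustrative sketch (a hypothetical helper, not part of the original
# module): the pivot loop above turns LAPACK-style sequential row swaps into
# one final permutation. `_pivots_to_permutation` replays the same swaps on a
# plain Python list.
def _pivots_to_permutation(pivots, n):
    # 0-based pivots here; pivots[k] == j means "swap rows k and j at step k"
    order = list(range(n))
    for k, j in enumerate(pivots):
        order[k], order[j] = order[j], order[k]
    return order

# For example, _pivots_to_permutation([2, 2, 2], 3) applies the swaps
# (0, 2), (1, 2), (2, 2) and yields [2, 0, 1].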

def einsum(equation, *operands):
    r"""einsum(equation, *operands) -> Tensor

    This function provides a way of computing multilinear expressions (i.e. sums of products) using the
    Einstein summation convention.

    Args:
        equation (string): The equation is given in terms of lower case letters (indices) to be associated
            with each dimension of the operands and result. The left hand side lists the operands
            dimensions, separated by commas. There should be one index letter per tensor dimension.
            The right hand side follows after `->` and gives the indices for the output.
            If the `->` and right hand side are omitted, it is implicitly defined as the alphabetically
            sorted list of all indices appearing exactly once in the left hand side.
            The indices not appearing in the output are summed over after multiplying the operands
            entrywise.
            If an index appears several times for the same operand, a diagonal is taken.
            Ellipses `...` represent a fixed number of dimensions. If the right hand side is inferred,
            the ellipsis dimensions are at the beginning of the output.
        operands (list of Tensors): The operands to compute the Einstein sum of.

    Examples::

        >>> x = torch.randn(5)
        >>> y = torch.randn(4)
        >>> torch.einsum('i,j->ij', x, y)  # outer product
        tensor([[-0.0570, -0.0286, -0.0231,  0.0197],
                [ 1.2616,  0.6335,  0.5113, -0.4351],
                [ 1.4452,  0.7257,  0.5857, -0.4984],
                [-0.4647, -0.2333, -0.1883,  0.1603],
                [-1.1130, -0.5588, -0.4510,  0.3838]])

        >>> A = torch.randn(3, 5, 4)
        >>> l = torch.randn(2, 5)
        >>> r = torch.randn(2, 4)
        >>> torch.einsum('bn,anm,bm->ba', l, A, r)  # compare torch.nn.functional.bilinear
        tensor([[-0.3430, -5.2405,  0.4494],
                [ 0.3311,  5.5201, -3.0356]])

        >>> As = torch.randn(3, 2, 5)
        >>> Bs = torch.randn(3, 5, 4)
        >>> torch.einsum('bij,bjk->bik', As, Bs)  # batch matrix multiplication
        tensor([[[-1.0564, -1.5904,  3.2023,  3.1271],
                 [-1.6706, -0.8097, -0.8025, -2.1183]],

                [[ 4.2239,  0.3107, -0.5756, -0.2354],
                 [-1.4558, -0.3460,  1.5087, -0.8530]],

                [[ 2.8153,  1.8787, -4.3839, -1.2112],
                 [ 0.3728, -2.1131,  0.0921,  0.8305]]])

        >>> A = torch.randn(3, 3)
        >>> torch.einsum('ii->i', A)  # diagonal
        tensor([-0.7825,  0.8291, -0.1936])

        >>> A = torch.randn(4, 3, 3)
        >>> torch.einsum('...ii->...i', A)  # batch diagonal
        tensor([[-1.0864,  0.7292,  0.0569],
                [-0.9725, -1.0270,  0.6493],
                [ 0.5832, -1.1716, -1.5084],
                [ 0.4041, -1.1690,  0.8570]])

        >>> A = torch.randn(2, 3, 4, 5)
        >>> torch.einsum('...ij->...ji', A).shape  # batch permute
        torch.Size([2, 3, 5, 4])
    """
    if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
        # the old interface of passing the operands as one list argument
        operands = operands[0]
    return torch._C._VariableFunctions.einsum(equation, operands)
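
# A sketch of the implicit-output form described above (a hypothetical helper,
# not part of the original module): when '->' is omitted, indices that appear
# more than once are summed over and the surviving indices are sorted
# alphabetically, so 'ij,jk' reduces to an ordinary matrix product.
def _demo_einsum_implicit():
    x = torch.randn(2, 3)
    y = torch.randn(3, 4)
    # 'j' is repeated and therefore summed; the output indices are 'ik'
    assert torch.allclose(torch.einsum('ij,jk', x, y), x @ y)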

def isfinite(tensor):
    r"""Returns a new tensor with boolean elements representing if each element is `Finite` or not.

    Arguments:
        tensor (Tensor): A tensor to check

    Returns:
        Tensor: A ``torch.ByteTensor`` containing a 1 at each location of finite elements and 0 otherwise

    Example::

        >>> torch.isfinite(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')]))
        tensor([ 1,  0,  1,  0,  0], dtype=torch.uint8)
    """
    if not isinstance(tensor, torch.Tensor):
        raise ValueError("The argument is not a tensor", str(tensor))

    # Integral tensors cannot hold NaN or infinity, so they are finite everywhere.
    if not tensor.is_floating_point():
        return torch.ones_like(tensor, dtype=torch.uint8)
    # NaN != NaN, so `tensor == tensor` filters out NaNs; the abs() comparison filters +/-inf.
    return (tensor == tensor) & (tensor.abs() != inf)
233 r"""Returns a new tensor with boolean elements representing if each element is `+/-INF` or not. 236 tensor (Tensor): A tensor to check 239 Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `+/-INF` elements and 0 otherwise 243 >>> torch.isinf(torch.tensor([1, float('inf'), 2, float('-inf'), float('nan')])) 244 tensor([ 0, 1, 0, 1, 0], dtype=torch.uint8) 246 if not isinstance(tensor, torch.Tensor):
247 raise ValueError(
"The argument is not a tensor", str(tensor))
248 if tensor.dtype
in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
249 return torch.zeros_like(tensor, dtype=torch.uint8)
250 return tensor.abs() == inf

def meshgrid(*tensors, **kwargs):
    r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional
    vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by
    expanding the :math:`i` :sup:`th` input over dimensions defined by other inputs.

    Args:
        tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
            treated as tensors of size :math:`(1,)` automatically

    Returns:
        seq (sequence of Tensors): If the input has :math:`k` tensors of size
        :math:`(N_1,), (N_2,), \ldots , (N_k,)`, then the output would also have :math:`k` tensors,
        where all tensors are of size :math:`(N_1, N_2, \ldots , N_k)`.

    Example::

        >>> x = torch.tensor([1, 2, 3])
        >>> y = torch.tensor([4, 5, 6])
        >>> grid_x, grid_y = torch.meshgrid(x, y)
    """
    if kwargs:
        raise TypeError("meshgrid() got an unexpected keyword argument '%s'" % (list(kwargs)[0],))
    if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
        # the old interface of passing the operands as one list argument
        tensors = tensors[0]
    return torch._C._VariableFunctions.meshgrid(tensors)
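
# A minimal usage sketch (a hypothetical helper, not part of the original
# module): each output grid repeats one input along the axes contributed by
# the other inputs, which is exactly the expansion described in the docstring.
def _demo_meshgrid():
    x = torch.tensor([1, 2, 3])
    y = torch.tensor([4, 5, 6])
    grid_x, grid_y = torch.meshgrid(x, y)
    # grid_x varies along rows, grid_y along columns
    assert grid_x.tolist() == [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
    assert grid_y.tolist() == [[4, 5, 6], [4, 5, 6], [4, 5, 6]]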

def stft(input, n_fft, hop_length=None, win_length=None, window=None,
         center=True, pad_mode='reflect', normalized=False, onesided=True):
    r"""Short-time Fourier transform (STFT).

    Ignoring the optional batch dimension, this method computes the following
    expression:

    .. math::
        X[m, \omega] = \sum_{k = 0}^{\text{win\_length} - 1}%
            \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
            \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{win\_length}}\right),

    where :math:`m` is the index of the sliding window, and :math:`\omega` is
    the frequency, with :math:`0 \leq \omega < \text{n\_fft}`.

    * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
      sequences.

    * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
      ``floor(n_fft / 4)``.

    * If :attr:`win_length` is ``None`` (default), it is treated as equal to
      :attr:`n_fft`.

    * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
      :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
      treated as if having :math:`1` everywhere in the window. If
      :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
      both sides to length :attr:`n_fft` before being applied.

    * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
      both sides so that the :math:`t`-th frame is centered at time
      :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
      begins at time :math:`t \times \text{hop\_length}`.

    * :attr:`pad_mode` determines the padding method used on :attr:`input` when
      :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
      all available options. Default is ``"reflect"``.

    * If :attr:`onesided` is ``True`` (default), only values for :math:`\omega`
      in :math:`\left[0, 1, 2, \dots, \left\lfloor \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]`
      are returned because the real-to-complex Fourier transform satisfies the
      conjugate symmetry, i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.

    * If :attr:`normalized` is ``True`` (default is ``False``), the function
      returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.

    Returns the real and the imaginary parts together as one tensor of size
    :math:`(* \times N \times T \times 2)`, where :math:`*` is the optional
    batch size of :attr:`input`, :math:`N` is the number of frequencies where
    STFT is applied, :math:`T` is the total number of frames used, and each pair
    in the last dimension represents a complex number as the real part and the
    imaginary part.

    .. warning::
        This function changed signature at version 0.4.1. Calling with the
        previous signature may cause an error or return an incorrect result.

    Arguments:
        input (Tensor): the input tensor
        n_fft (int): size of Fourier transform
        hop_length (int, optional): the distance between neighboring sliding window
            frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
        win_length (int, optional): the size of window frame and STFT filter.
            Default: ``None`` (treated as equal to :attr:`n_fft`)
        window (Tensor, optional): the optional window function.
            Default: ``None`` (treated as window of all :math:`1` s)
        center (bool, optional): whether to pad :attr:`input` on both sides so
            that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
            Default: ``True``
        pad_mode (string, optional): controls the padding method used when
            :attr:`center` is ``True``. Default: ``"reflect"``
        normalized (bool, optional): controls whether to return the normalized
            STFT results. Default: ``False``
        onesided (bool, optional): controls whether to return half of results to
            avoid redundancy. Default: ``True``

    Returns:
        Tensor: A tensor containing the STFT result with shape described above
    """
    if center:
        # center each frame at t * hop_length by padding n_fft // 2 on each side
        signal_dim = input.dim()
        extended_shape = [1] * (3 - signal_dim) + list(input.size())
        pad = int(n_fft // 2)
        input = F.pad(input.view(extended_shape), (pad, pad), pad_mode)
        input = input.view(input.shape[-signal_dim:])
    return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
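
# A shape-checking sketch (a hypothetical helper, not part of the original
# module, assuming this version's real-valued output with a trailing
# real/imaginary dimension of 2): with center=True the signal is padded by
# n_fft // 2 on both sides, so the number of frames is len // hop_length + 1,
# and the one-sided output has n_fft // 2 + 1 frequency bins.
def _demo_stft_shape():
    signal = torch.randn(1000)
    n_fft, hop = 256, 64  # hop_length also defaults to n_fft // 4
    out = torch.stft(signal, n_fft, hop_length=hop)
    freqs, frames, ri = out.shape
    assert freqs == n_fft // 2 + 1    # one-sided spectrum
    assert frames == 1000 // hop + 1  # centered framing
    assert ri == 2                    # real and imaginary parts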

isnan = _add_docstr(torch.isnan, r"""
Returns a new tensor with boolean elements representing if each element is `NaN` or not.

Arguments:
    tensor (Tensor): A tensor to check

Returns:
    Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.

Example::

    >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
    tensor([ 0,  1,  0], dtype=torch.uint8)
""")


def unique(input, sorted=True, return_inverse=False, dim=None):
    r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.

    Arguments:
        input (Tensor): the input tensor
        sorted (bool): Whether to sort the unique elements in ascending order
            before returning as output.
        return_inverse (bool): Whether to also return the indices for where
            elements in the original input ended up in the returned unique list.
        dim (int): the dimension to apply unique. If ``None``, the unique of the
            flattened input is returned. Default: ``None``

    Returns:
        (Tensor, Tensor (optional)): A tensor or a tuple of tensors containing

            - **output** (*Tensor*): the output list of unique scalar elements.
            - **inverse_indices** (*Tensor*): (optional) if
              :attr:`return_inverse` is True, there will be a
              2nd returned tensor (same shape as input) representing the indices
              for where elements in the original input map to in the output;
              otherwise, this function will only return a single tensor.

    Example::

        >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
        >>> output
        tensor([ 1,  2,  3])

        >>> output, inverse_indices = torch.unique(
                torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([ 1,  2,  3])
        >>> inverse_indices
        tensor([ 0,  2,  1,  2])

        >>> output, inverse_indices = torch.unique(
                torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
        >>> output
        tensor([ 1,  2,  3])
        >>> inverse_indices
        tensor([[ 0,  2],
                [ 1,  2]])
    """
    if dim is not None:
        output, inverse_indices = torch._unique_dim(
            input,
            dim,
            sorted=sorted,
            return_inverse=return_inverse
        )
    else:
        output, inverse_indices = torch._unique(
            input,
            sorted=sorted,
            return_inverse=return_inverse,
        )
    if return_inverse:
        return output, inverse_indices
    else:
        return output
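
# A property sketch (a hypothetical helper, not part of the original module):
# the inverse indices reconstruct the original tensor from the unique values,
# i.e. output[inverse_indices] == input.
def _demo_unique_inverse():
    x = torch.tensor([1, 3, 2, 3])
    output, inverse = torch.unique(x, sorted=True, return_inverse=True)
    assert torch.equal(output[inverse], x)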

def tensordot(a, b, dims=2):
    r"""Returns a contraction of a and b over multiple dimensions.

    :attr:`tensordot` implements a generalized matrix product.

    Args:
        a (Tensor): Left tensor to contract
        b (Tensor): Right tensor to contract
        dims (int or tuple of two lists of integers): number of dimensions to
            contract or explicit lists of dimensions for :attr:`a` and
            :attr:`b` respectively

    When called with an integer argument :attr:`dims` = :math:`d`, and the number of
    dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, respectively,
    it computes

    .. math::
        r_{i_0,...,i_{m-d}, i_d,...,i_n}
          = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.

    When called with :attr:`dims` of the list form, the given dimensions will be contracted
    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :attr:`b`. The sizes
    in these dimensions must match, but :attr:`tensordot` will deal with broadcasted
    dimensions.

    Examples::

        >>> a = torch.arange(60.).reshape(3, 4, 5)
        >>> b = torch.arange(24.).reshape(4, 3, 2)
        >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
        tensor([[4400., 4730.],
                [4532., 4874.],
                [4664., 5018.]])

        >>> a = torch.randn(3, 4, 5, device='cuda')
        >>> b = torch.randn(4, 5, 6, device='cuda')
        >>> c = torch.tensordot(a, b, dims=2).cpu()
        tensor([[ 8.3504, -2.5436,  6.2922,  2.7556, -1.0732,  3.2741],
                [ 3.3161,  0.0704,  5.0187, -0.4079, -4.3126,  4.8744],
                [ 0.8223,  3.9445,  3.2168, -0.2400,  3.4117,  1.7780]])
    """
    if isinstance(dims, (list, tuple)) or \
            (isinstance(dims, torch.Tensor) and dims.numel() > 1):
        dims_a, dims_b = dims
    else:
        if isinstance(dims, torch.Tensor):
            # a 0-dim tensor carries the number of dimensions to contract
            dims = dims.item()
        # contract the last `dims` dimensions of `a` with the first `dims` of `b`
        dims_a = list(range(-dims, 0))
        dims_b = list(range(dims))
    return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
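
# An equivalence sketch (a hypothetical helper, not part of the original
# module): with an integer `dims=d`, tensordot is the same as flattening the
# last d dimensions of `a` and the first d dimensions of `b` and performing an
# ordinary matrix multiplication.
def _demo_tensordot_as_matmul():
    a = torch.randn(3, 4, 5)
    b = torch.randn(4, 5, 6)
    out = torch.tensordot(a, b, dims=2)
    ref = a.reshape(3, 20) @ b.reshape(20, 6)
    assert torch.allclose(out, ref, atol=1e-5)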

def cartesian_prod(*tensors):
    """Does the Cartesian product of the given sequence of tensors. The behavior is similar to
    python's `itertools.product`.

    Arguments:
        *tensors: any number of 1 dimensional tensors.

    Returns:
        Tensor: A tensor equivalent to converting all the input tensors into lists,
        doing `itertools.product` on these lists, and finally converting the resulting
        list into a tensor.

    Example::

        >>> a = [1, 2, 3]
        >>> b = [4, 5]
        >>> list(itertools.product(a, b))
        [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
        >>> tensor_a = torch.tensor(a)
        >>> tensor_b = torch.tensor(b)
        >>> torch.cartesian_prod(tensor_a, tensor_b)
        tensor([[1, 4],
                [1, 5],
                [2, 4],
                [2, 5],
                [3, 4],
                [3, 5]])
    """
    return torch._C._VariableFunctions.cartesian_prod(tensors)

def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None):
    r"""Returns the matrix norm or vector norm of a given tensor.

    Args:
        input (Tensor): the input tensor
        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
            The following norms can be calculated:

            =====  ============================  ==========================
            ord    matrix norm                   vector norm
            =====  ============================  ==========================
            None   Frobenius norm                2-norm
            'fro'  Frobenius norm                --
            'nuc'  nuclear norm                  --
            Other  as vec norm when dim is None  sum(abs(x)**ord)**(1./ord)
            =====  ============================  ==========================

        dim (int, 2-tuple of ints, 2-list of ints, optional): If it is an int,
            vector norm will be calculated, if it is 2-tuple of ints, matrix norm
            will be calculated. If the value is None, matrix norm will be calculated
            when the input tensor only has two dimensions, vector norm will be
            calculated when the input tensor only has one dimension. If the input
            tensor has more than two dimensions, the vector norm will be applied to
            the last dimension.
        keepdim (bool, optional): whether the output tensors have :attr:`dim`
            retained or not. Ignored if :attr:`dim` = ``None`` and
            :attr:`out` = ``None``. Default: ``False``
        out (Tensor, optional): the output tensor. Ignored if
            :attr:`dim` = ``None`` and :attr:`out` = ``None``.
        dtype (:class:`torch.dtype`, optional): the desired data type of
            returned tensor. If specified, the input tensor is casted to
            :attr:`dtype` while performing the operation. Default: ``None``.

    Example::

        >>> a = torch.arange(9, dtype=torch.float) - 4
        >>> b = a.reshape((3, 3))
        >>> torch.norm(a)
        tensor(7.7460)
        >>> torch.norm(b)
        tensor(7.7460)
        >>> torch.norm(a, float('inf'))
        tensor(4.)
        >>> torch.norm(b, float('inf'))
        tensor(4.)
        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]], dtype=torch.float)
        >>> torch.norm(c, dim=0)
        tensor([1.4142, 2.2361, 5.0000])
        >>> torch.norm(c, dim=1)
        tensor([3.7417, 4.2426])
        >>> torch.norm(c, p=1, dim=1)
        tensor([6., 6.])
        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
        >>> torch.norm(d, dim=(1, 2))
        tensor([ 3.7417, 11.2250])
        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
        (tensor(3.7417), tensor(11.2250))
    """
    ndim = input.dim()

    # catch the default case for speed: full reduction with no out/dtype
    if dim is None and out is None and dtype is None:
        if p == "fro":
            return torch._C._VariableFunctions.frobenius_norm(input)
        elif p != "nuc":
            return torch._C._VariableFunctions.norm(input, p)

    if p == "fro":
        if dtype is not None:
            raise ValueError("dtype argument is not supported in frobenius norm")
        if dim is None:
            dim = tuple(range(ndim))
        if out is None:
            return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim)
        return torch._C._VariableFunctions.frobenius_norm(input, dim, keepdim=keepdim, out=out)
    elif p == "nuc":
        if dtype is not None:
            raise ValueError("dtype argument is not supported in nuclear norm")
        if out is None:
            return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim)
        return torch._C._VariableFunctions.nuclear_norm(input, keepdim=keepdim, out=out)
    else:
        if dim is None:
            dim = tuple(range(ndim))
        if out is None and dtype is None:
            return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim)
        elif out is None:
            return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype)
        elif dtype is None:
            return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, out=out)
        return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out)
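
# A consistency sketch (a hypothetical helper, not part of the original
# module): the dispatch above means the default call equals an explicit
# Frobenius norm, which in turn equals the 2-norm of the flattened tensor.
def _demo_norm_dispatch():
    x = torch.randn(3, 4)
    assert torch.allclose(torch.norm(x), torch.norm(x, p='fro'))
    assert torch.allclose(torch.norm(x), torch.norm(x.reshape(-1), p=2))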

def chain_matmul(*matrices):
    r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
    using the matrix chain order algorithm, which selects the order that incurs the lowest cost in terms
    of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
    needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
    If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.

    Args:
        matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.

    Returns:
        Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
        would be of dimensions :math:`p_{1} \times p_{N + 1}`.

    Example::

        >>> a = torch.randn(3, 4)
        >>> b = torch.randn(4, 5)
        >>> c = torch.randn(5, 6)
        >>> d = torch.randn(6, 7)
        >>> torch.chain_matmul(a, b, c, d)
        tensor([[ -2.3375,  -3.9790,  -4.1119,  -6.6577,   9.5609, -11.5095,  -3.2614],
                [ 21.4038,   3.3378,  -8.4982,  -5.2457, -10.2561,  -2.4684,   2.7163],
                [ -0.9647,  -5.8917,  -2.3213,  -5.2284,  12.8615, -12.2816,  -2.5095]])

    .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
    """
    return torch._C._VariableFunctions.chain_matmul(matrices)

def potrf(a, upper=True, out=None):
    r"""Computes the Cholesky decomposition of a symmetric positive-definite
    matrix :attr:`a`.

    For more information regarding :func:`torch.potrf`, please check :func:`torch.cholesky`.

    .. warning::
        :func:`torch.potrf` is deprecated in favour of :func:`torch.cholesky` and will be removed
        in the next release. Please use :func:`torch.cholesky` instead and note that the :attr:`upper`
        argument in :func:`torch.cholesky` defaults to ``False``.
    """
    warnings.warn("torch.potrf is deprecated in favour of torch.cholesky and will be removed in the next "
                  "release. Please use torch.cholesky instead and note that the :attr:`upper` argument in"
                  " torch.cholesky defaults to ``False``.", stacklevel=2)
    return torch.cholesky(a, upper=upper, out=out)
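
# A migration sketch (a hypothetical helper, not part of the original module):
# potrf defaults to the upper factor while cholesky defaults to the lower one,
# so a direct replacement must pass `upper` explicitly.
def _demo_potrf_migration():
    a = torch.randn(3, 3)
    a = a @ a.t() + 3 * torch.eye(3)  # make symmetric positive definite
    u_old = torch.potrf(a)                 # deprecated; upper=True by default
    u_new = torch.cholesky(a, upper=True)  # explicit upper factor
    assert torch.allclose(u_old, u_new)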

def pstrf(a, upper=True, out=None):
    r"""Computes the pivoted Cholesky decomposition of a symmetric positive-definite
    matrix :attr:`a`, returning a namedtuple (u, pivot) of matrices.

    If :attr:`upper` is ``True`` or not provided, `u` is upper triangular
    such that :math:`a = p^T u^T u p`, with `p` the permutation given by `pivot`.

    If :attr:`upper` is ``False``, `u` is lower triangular such that
    :math:`a = p^T u u^T p`.

    .. warning::
        :func:`torch.pstrf` is deprecated in favour of :func:`torch.cholesky` and will
        be removed in the next release.

    Args:
        a (Tensor): the input 2-D tensor
        upper (bool, optional): whether to return an upper (default) or lower triangular matrix
        out (tuple, optional): namedtuple of `u` and `pivot` tensors

    Example::

        >>> a = torch.randn(3, 3)
        >>> a = torch.mm(a, a.t()) # make symmetric positive definite
        >>> a
        tensor([[ 3.5405, -0.4577,  0.8342],
                [-0.4577,  1.8244, -0.1996],
                [ 0.8342, -0.1996,  3.7493]])
        >>> u, piv = torch.pstrf(a)
        >>> u
        tensor([[ 1.9363,  0.4308, -0.1031],
                [ 0.0000,  1.8316, -0.2256],
                [ 0.0000,  0.0000,  1.3277]])
        >>> piv
        tensor([ 2,  0,  1], dtype=torch.int32)
        >>> p = torch.eye(3).index_select(0, piv.long()).index_select(0, piv.long()).t() # make pivot permutation
        >>> torch.mm(torch.mm(p.t(), torch.mm(u.t(), u)), p) # reconstruct
        tensor([[ 3.5405, -0.4577,  0.8342],
                [-0.4577,  1.8244, -0.1996],
                [ 0.8342, -0.1996,  3.7493]])
    """
    warnings.warn("torch.pstrf is deprecated in favour of torch.cholesky and will be removed "
                  "in the next release.", stacklevel=2)
    return torch._C._VariableFunctions.pstrf(a, upper=upper, out=out)

def potrs(b, u, upper=True, out=None):
    r"""Solves a linear system of equations with a positive semidefinite
    matrix to be inverted given its Cholesky factor matrix :attr:`u`.

    For more information regarding :func:`torch.potrs`, please check :func:`torch.cholesky_solve`.

    .. warning::
        :func:`torch.potrs` is deprecated in favour of :func:`torch.cholesky_solve` and will be
        removed in the next release. Please use :func:`torch.cholesky_solve` instead and note that
        the :attr:`upper` argument in :func:`torch.cholesky_solve` defaults to ``False``.
    """
    warnings.warn("torch.potrs is deprecated in favour of torch.cholesky_solve and will be removed "
                  "in the next release. Please use torch.cholesky_solve instead and note that the "
                  ":attr:`upper` argument in torch.cholesky_solve defaults to ``False``.", stacklevel=2)
    return torch.cholesky_solve(b, u, upper=upper, out=out)

def gesv(b, A, out=None):
    r"""This function returns the solution to the system of linear equations represented
    by :math:`AX = B` and the LU factorization of :math:`A`, in order as a tuple `X, LU`.

    For more information regarding :func:`torch.gesv`, please check :func:`torch.solve`.

    .. warning::
        :func:`torch.gesv` is deprecated in favour of :func:`torch.solve` and will be removed in the
        next release. Please use :func:`torch.solve` instead.
    """
    warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
                  "next release. Please use torch.solve instead.", stacklevel=2)
    return torch.solve(b, A, out=out)