|  | # -*- coding: utf-8 -*- | 
|  | import sys | 
|  |  | 
|  | import torch | 
|  | from torch._C import _add_docstr, _linalg  # type: ignore[attr-defined] | 
|  |  | 
|  | LinAlgError = torch._C._LinAlgError  # type: ignore[attr-defined] | 
|  |  | 
|  | Tensor = torch.Tensor | 
|  |  | 
|  | common_notes = { | 
|  | "experimental_warning": """This function is "experimental" and it may change in a future PyTorch release.""", | 
|  | "sync_note": "When inputs are on a CUDA device, this function synchronizes that device with the CPU.", | 
|  | "sync_note_ex": r"When the inputs are on a CUDA device, this function synchronizes only when :attr:`check_errors`\ `= True`.", | 
|  | "sync_note_has_ex": ("When inputs are on a CUDA device, this function synchronizes that device with the CPU. " | 
|  | "For a version of this function that does not synchronize, see :func:`{}`.") | 
|  | } | 
|  |  | 
|  |  | 
|  | # Note: This not only adds doc strings for functions in the linalg namespace, but | 
|  | # also connects the torch.linalg Python namespace to the torch._C._linalg builtins. | 
|  |  | 
|  | cross = _add_docstr(_linalg.linalg_cross, r""" | 
|  | linalg.cross(input, other, *, dim=-1, out=None) -> Tensor | 
|  |  | 
|  |  | 
|  | Computes the cross product of two 3-dimensional vectors. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. Also supports batches | 
|  | of vectors, for which it computes the product along the dimension :attr:`dim`. | 
|  | It broadcasts over the batch dimensions. | 
|  |  | 
|  | Args: | 
|  | input (Tensor): the first input tensor. | 
|  | other (Tensor): the second input tensor. | 
|  | dim  (int, optional): the dimension along which to take the cross-product. Default: `-1`. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): the output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Example: | 
|  | >>> a = torch.randn(4, 3) | 
|  | >>> a | 
|  | tensor([[-0.3956,  1.1455,  1.6895], | 
|  | [-0.5849,  1.3672,  0.3599], | 
|  | [-1.1626,  0.7180, -0.0521], | 
|  | [-0.1339,  0.9902, -2.0225]]) | 
|  | >>> b = torch.randn(4, 3) | 
|  | >>> b | 
|  | tensor([[-0.0257, -1.4725, -1.2251], | 
|  | [-1.1479, -0.7005, -1.9757], | 
|  | [-1.3904,  0.3726, -1.1836], | 
|  | [-0.9688, -0.7153,  0.2159]]) | 
|  | >>> torch.linalg.cross(a, b) | 
|  | tensor([[ 1.0844, -0.5281,  0.6120], | 
|  | [-2.4490, -1.5687,  1.9792], | 
|  | [-0.8304, -1.3037,  0.5650], | 
|  | [-1.2329,  1.9883,  1.0551]]) | 
|  | >>> a = torch.randn(1, 3)  # a is broadcast to match shape of b | 
|  | >>> a | 
|  | tensor([[-0.9941, -0.5132,  0.5681]]) | 
|  | >>> torch.linalg.cross(a, b) | 
|  | tensor([[ 1.4653, -1.2325,  1.4507], | 
|  | [ 1.4119, -2.6163,  0.1073], | 
|  | [ 0.3957, -1.9666, -1.0840], | 
|  | [ 0.2956, -0.3357,  0.2139]]) | 
|  | """) | 
|  |  | 
|  | cholesky = _add_docstr(_linalg.linalg_cholesky, r""" | 
|  | linalg.cholesky(A, *, upper=False, out=None) -> Tensor | 
|  |  | 
|  | Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **Cholesky decomposition** of a complex Hermitian or real symmetric positive-definite matrix | 
|  | :math:`A \in \mathbb{K}^{n \times n}` is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = LL^{\text{H}}\mathrlap{\qquad L \in \mathbb{K}^{n \times n}} | 
|  |  | 
|  | where :math:`L` is a lower triangular matrix with real positive diagonal (even in the complex case) and | 
|  | :math:`L^{\text{H}}` is the conjugate transpose when :math:`L` is complex, and the transpose when :math:`L` is real-valued. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.cholesky_ex` for a version of this operation that | 
|  | skips the (slow) error checking by default and instead returns the debug | 
|  | information. This makes it a faster way to check if a matrix is | 
|  | positive-definite. | 
|  |  | 
|  | :func:`torch.linalg.eigh` for a different decomposition of a Hermitian matrix. | 
|  | The eigenvalue decomposition gives more information about the matrix but it | 
|  | slower to compute than the Cholesky decomposition. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions | 
|  | consisting of symmetric or Hermitian positive-definite matrices. | 
|  |  | 
|  | Keyword args: | 
|  | upper (bool, optional): whether to return an upper triangular matrix. | 
|  | The tensor returned with upper=True is the conjugate transpose of the tensor | 
|  | returned with upper=False. | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if the :attr:`A` matrix or any matrix in a batched :attr:`A` is not Hermitian | 
|  | (resp. symmetric) positive-definite. If :attr:`A` is a batch of matrices, | 
|  | the error message will include the batch index of the first matrix that fails | 
|  | to meet this condition. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 2, dtype=torch.complex128) | 
|  | >>> A = A @ A.T.conj() + torch.eye(2) # creates a Hermitian positive-definite matrix | 
|  | >>> A | 
|  | tensor([[2.5266+0.0000j, 1.9586-2.0626j], | 
|  | [1.9586+2.0626j, 9.4160+0.0000j]], dtype=torch.complex128) | 
|  | >>> L = torch.linalg.cholesky(A) | 
|  | >>> L | 
|  | tensor([[1.5895+0.0000j, 0.0000+0.0000j], | 
|  | [1.2322+1.2976j, 2.4928+0.0000j]], dtype=torch.complex128) | 
|  | >>> torch.dist(L @ L.T.conj(), A) | 
|  | tensor(4.4692e-16, dtype=torch.float64) | 
|  |  | 
|  | >>> A = torch.randn(3, 2, 2, dtype=torch.float64) | 
|  | >>> A = A @ A.mT + torch.eye(2)  # batch of symmetric positive-definite matrices | 
|  | >>> L = torch.linalg.cholesky(A) | 
|  | >>> torch.dist(L @ L.mT, A) | 
|  | tensor(5.8747e-16, dtype=torch.float64) | 
|  | """) | 
|  |  | 
|  | cholesky_ex = _add_docstr(_linalg.linalg_cholesky_ex, r""" | 
|  | linalg.cholesky_ex(A, *, upper=False, check_errors=False, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes the Cholesky decomposition of a complex Hermitian or real | 
|  | symmetric positive-definite matrix. | 
|  |  | 
|  | This function skips the (slow) error checking and error message construction | 
|  | of :func:`torch.linalg.cholesky`, instead directly returning the LAPACK | 
|  | error codes as part of a named tuple ``(L, info)``. This makes this function | 
|  | a faster way to check if a matrix is positive-definite, and it provides an | 
|  | opportunity to handle decomposition errors more gracefully or performantly | 
|  | than :func:`torch.linalg.cholesky` does. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | If :attr:`A` is not a Hermitian positive-definite matrix, or if it's a batch of matrices | 
|  | and one or more of them is not a Hermitian positive-definite matrix, | 
|  | then ``info`` stores a positive integer for the corresponding matrix. | 
|  | The positive integer indicates the order of the leading minor that is not positive-definite, | 
|  | and the decomposition could not be completed. | 
|  | ``info`` filled with zeros indicates that the decomposition was successful. | 
|  | If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note_ex"]} | 
|  |  | 
|  | .. warning:: {common_notes["experimental_warning"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. seealso:: | 
|  | :func:`torch.linalg.cholesky` is a NumPy compatible variant that always checks for errors. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): the Hermitian `n \times n` matrix or the batch of such matrices of size | 
|  | `(*, n, n)` where `*` is one or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | upper (bool, optional): whether to return an upper triangular matrix. | 
|  | The tensor returned with upper=True is the conjugate transpose of the tensor | 
|  | returned with upper=False. | 
|  | check_errors (bool, optional): controls whether to check the content of ``infos``. Default: `False`. | 
|  | out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 2, dtype=torch.complex128) | 
|  | >>> A = A @ A.t().conj()  # creates a Hermitian positive-definite matrix | 
|  | >>> L, info = torch.linalg.cholesky_ex(A) | 
|  | >>> A | 
|  | tensor([[ 2.3792+0.0000j, -0.9023+0.9831j], | 
|  | [-0.9023-0.9831j,  0.8757+0.0000j]], dtype=torch.complex128) | 
|  | >>> L | 
|  | tensor([[ 1.5425+0.0000j,  0.0000+0.0000j], | 
|  | [-0.5850-0.6374j,  0.3567+0.0000j]], dtype=torch.complex128) | 
|  | >>> info | 
|  | tensor(0, dtype=torch.int32) | 
|  |  | 
|  | """) | 
|  |  | 
|  | inv = _add_docstr(_linalg.linalg_inv, r""" | 
|  | linalg.inv(A, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the inverse of a square matrix if it exists. | 
|  | Throws a `RuntimeError` if the matrix is not invertible. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | for a matrix :math:`A \in \mathbb{K}^{n \times n}`, | 
|  | its **inverse matrix** :math:`A^{-1} \in \mathbb{K}^{n \times n}` (if it exists) is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A^{-1}A = AA^{-1} = \mathrm{I}_n | 
|  |  | 
|  | where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. | 
|  |  | 
|  | The inverse matrix exists if and only if :math:`A` is `invertible`_. In this case, | 
|  | the inverse is unique. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices | 
|  | then the output has the same batch dimensions. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. note:: | 
|  | Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by | 
|  | the inverse, as:: | 
|  |  | 
|  | linalg.solve(A, B) == linalg.inv(A) @ B  # When B is a matrix | 
|  |  | 
|  | It is always prefered to use :func:`~solve` when possible, as it is faster and more | 
|  | numerically stable than computing the inverse explicitly. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.pinv` computes the pseudoinverse (Moore-Penrose inverse) of matrices | 
|  | of any shape. | 
|  |  | 
|  | :func:`torch.linalg.solve` computes :attr:`A`\ `.inv() @ \ `:attr:`B` with a | 
|  | numerically stable algorithm. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions | 
|  | consisting of invertible matrices. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if the matrix :attr:`A` or any matrix in the batch of matrices :attr:`A` is not invertible. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(4, 4) | 
|  | >>> Ainv = torch.linalg.inv(A) | 
|  | >>> torch.dist(A @ Ainv, torch.eye(4)) | 
|  | tensor(1.1921e-07) | 
|  |  | 
|  | >>> A = torch.randn(2, 3, 4, 4)  # Batch of matrices | 
|  | >>> Ainv = torch.linalg.inv(A) | 
|  | >>> torch.dist(A @ Ainv, torch.eye(4)) | 
|  | tensor(1.9073e-06) | 
|  |  | 
|  | >>> A = torch.randn(4, 4, dtype=torch.complex128)  # Complex matrix | 
|  | >>> Ainv = torch.linalg.inv(A) | 
|  | >>> torch.dist(A @ Ainv, torch.eye(4)) | 
|  | tensor(7.5107e-16, dtype=torch.float64) | 
|  |  | 
|  | .. _invertible: | 
|  | https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem | 
|  | """) | 
|  |  | 
|  | solve_ex = _add_docstr(_linalg.linalg_solve_ex, r""" | 
|  | linalg.solve_ex(A, B, *, left=True, check_errors=False, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | A version of :func:`~solve` that does not perform error checks unless :attr:`check_errors`\ `= True`. | 
|  | It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note_ex"]} | 
|  |  | 
|  | .. warning:: {common_notes["experimental_warning"]} | 
|  | """ + r""" | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. | 
|  | check_errors (bool, optional): controls whether to check the content of ``infos`` and raise | 
|  | an error if it is non-zero. Default: `False`. | 
|  | out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(result, info)`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> Ainv, info = torch.linalg.solve_ex(A) | 
|  | >>> torch.dist(torch.linalg.inv(A), Ainv) | 
|  | tensor(0.) | 
|  | >>> info | 
|  | tensor(0, dtype=torch.int32) | 
|  |  | 
|  | .. _LAPACK's getrf: | 
|  | https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html | 
|  | """) | 
|  |  | 
|  | inv_ex = _add_docstr(_linalg.linalg_inv_ex, r""" | 
|  | linalg.inv_ex(A, *, check_errors=False, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes the inverse of a square matrix if it is invertible. | 
|  |  | 
|  | Returns a namedtuple ``(inverse, info)``. ``inverse`` contains the result of | 
|  | inverting :attr:`A` and ``info`` stores the LAPACK error codes. | 
|  |  | 
|  | If :attr:`A` is not an invertible matrix, or if it's a batch of matrices | 
|  | and one or more of them is not an invertible matrix, | 
|  | then ``info`` stores a positive integer for the corresponding matrix. | 
|  | The positive integer indicates the diagonal element of the LU decomposition of | 
|  | the input matrix that is exactly zero. | 
|  | ``info`` filled with zeros indicates that the inversion was successful. | 
|  | If ``check_errors=True`` and ``info`` contains positive integers, then a RuntimeError is thrown. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note_ex"]} | 
|  |  | 
|  | .. warning:: {common_notes["experimental_warning"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.inv` is a NumPy compatible variant that always checks for errors. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions | 
|  | consisting of square matrices. | 
|  | check_errors (bool, optional): controls whether to check the content of ``info``. Default: `False`. | 
|  |  | 
|  | Keyword args: | 
|  | out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> Ainv, info = torch.linalg.inv_ex(A) | 
|  | >>> torch.dist(torch.linalg.inv(A), Ainv) | 
|  | tensor(0.) | 
|  | >>> info | 
|  | tensor(0, dtype=torch.int32) | 
|  |  | 
|  | """) | 
|  |  | 
|  | det = _add_docstr(_linalg.linalg_det, r""" | 
|  | linalg.det(A, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the determinant of a square matrix. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.slogdet` computes the sign (resp. angle) and natural logarithm of the | 
|  | absolute value of the determinant of real-valued (resp. complex) square matrices. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> torch.linalg.det(A) | 
|  | tensor(0.0934) | 
|  |  | 
|  | >>> A = torch.randn(3, 2, 2) | 
|  | >>> torch.linalg.det(A) | 
|  | tensor([1.1990, 0.4099, 0.7386]) | 
|  | """) | 
|  |  | 
|  | slogdet = _add_docstr(_linalg.linalg_slogdet, r""" | 
|  | linalg.slogdet(A, *, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes the sign and natural logarithm of the absolute value of the determinant of a square matrix. | 
|  |  | 
|  | For complex :attr:`A`, it returns the angle and the natural logarithm of the modulus of the | 
|  | determinant, that is, a logarithmic polar decomposition of the determinant. | 
|  |  | 
|  | The determinant can be recovered as `sign * exp(logabsdet)`. | 
|  | When a matrix has a determinant of zero, it returns `(0, -inf)`. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.det` computes the determinant of square matrices. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(sign, logabsdet)`. | 
|  |  | 
|  | `sign` will have the same dtype as :attr:`A`. | 
|  |  | 
|  | `logabsdet` will always be real-valued, even when :attr:`A` is complex. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> A | 
|  | tensor([[ 0.0032, -0.2239, -1.1219], | 
|  | [-0.6690,  0.1161,  0.4053], | 
|  | [-1.6218, -0.9273, -0.0082]]) | 
|  | >>> torch.linalg.det(A) | 
|  | tensor(-0.7576) | 
|  | >>> torch.logdet(A) | 
|  | tensor(nan) | 
|  | >>> torch.linalg.slogdet(A) | 
|  | torch.return_types.linalg_slogdet(sign=tensor(-1.), logabsdet=tensor(-0.2776)) | 
|  | """) | 
|  |  | 
|  | eig = _add_docstr(_linalg.linalg_eig, r""" | 
|  | linalg.eig(A, *, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes the eigenvalue decomposition of a square matrix if it exists. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **eigenvalue decomposition** of a square matrix | 
|  | :math:`A \in \mathbb{K}^{n \times n}` (if it exists) is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = V \operatorname{diag}(\Lambda) V^{-1}\mathrlap{\qquad V \in \mathbb{C}^{n \times n}, \Lambda \in \mathbb{C}^n} | 
|  |  | 
|  | This decomposition exists if and only if :math:`A` is `diagonalizable`_. | 
|  | This is the case when all its eigenvalues are different. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | .. note:: The eigenvalues and eigenvectors of a real matrix may be complex. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. warning:: This function assumes that :attr:`A` is `diagonalizable`_ (for example, when all the | 
|  | eigenvalues are different). If it is not diagonalizable, the returned | 
|  | eigenvalues will be correct but :math:`A \neq V \operatorname{diag}(\Lambda)V^{-1}`. | 
|  |  | 
|  | .. warning:: The returned eigenvectors are normalized to have norm `1`. | 
|  | Even then, the eigenvectors of a matrix are not unique, nor are they continuous with respect to | 
|  | :attr:`A`. Due to this lack of uniqueness, different hardware and software may compute | 
|  | different eigenvectors. | 
|  |  | 
|  | This non-uniqueness is caused by the fact that multiplying an eigenvector by | 
|  | by :math:`e^{i \phi}, \phi \in \mathbb{R}` produces another set of valid eigenvectors | 
|  | of the matrix.  For this reason, the loss function shall not depend on the phase of the | 
|  | eigenvectors, as this quantity is not well-defined. | 
|  | This is checked when computing the gradients of this function. As such, | 
|  | when inputs are on a CUDA device, this function synchronizes that device with the CPU | 
|  | when computing the gradients. | 
|  | This is checked when computing the gradients of this function. As such, | 
|  | when inputs are on a CUDA device, the computation of the gradients | 
|  | of this function synchronizes that device with the CPU. | 
|  |  | 
|  |  | 
|  | .. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when | 
|  | :attr:`A` has distinct eigenvalues. | 
|  | Furthermore, if the distance between any two eigenvalues is close to zero, | 
|  | the gradient will be numerically unstable, as it depends on the eigenvalues | 
|  | :math:`\lambda_i` through the computation of | 
|  | :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.eigvals` computes only the eigenvalues. | 
|  | Unlike :func:`torch.linalg.eig`, the gradients of :func:`~eigvals` are always | 
|  | numerically stable. | 
|  |  | 
|  | :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition | 
|  | for Hermitian and symmetric matrices. | 
|  |  | 
|  | :func:`torch.linalg.svd` for a function that computes another type of spectral | 
|  | decomposition that works on matrices of any shape. | 
|  |  | 
|  | :func:`torch.linalg.qr` for another (much faster) decomposition that works on matrices of | 
|  | any shape. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions | 
|  | consisting of diagonalizable matrices. | 
|  |  | 
|  | Keyword args: | 
|  | out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`V` above. | 
|  |  | 
|  | `eigenvalues` and `eigenvectors` will always be complex-valued, even when :attr:`A` is real. The eigenvectors | 
|  | will be given by the columns of `eigenvectors`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 2, dtype=torch.complex128) | 
|  | >>> A | 
|  | tensor([[ 0.9828+0.3889j, -0.4617+0.3010j], | 
|  | [ 0.1662-0.7435j, -0.6139+0.0562j]], dtype=torch.complex128) | 
|  | >>> L, V = torch.linalg.eig(A) | 
|  | >>> L | 
|  | tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) | 
|  | >>> V | 
|  | tensor([[ 0.9218+0.0000j,  0.1882-0.2220j], | 
|  | [-0.0270-0.3867j,  0.9567+0.0000j]], dtype=torch.complex128) | 
|  | >>> torch.dist(V @ torch.diag(L) @ torch.linalg.inv(V), A) | 
|  | tensor(7.7119e-16, dtype=torch.float64) | 
|  |  | 
|  | >>> A = torch.randn(3, 2, 2, dtype=torch.float64) | 
|  | >>> L, V = torch.linalg.eig(A) | 
|  | >>> torch.dist(V @ torch.diag_embed(L) @ torch.linalg.inv(V), A) | 
|  | tensor(3.2841e-16, dtype=torch.float64) | 
|  |  | 
|  | .. _diagonalizable: | 
|  | https://en.wikipedia.org/wiki/Diagonalizable_matrix#Definition | 
|  | """) | 
|  |  | 
|  | eigvals = _add_docstr(_linalg.linalg_eigvals, r""" | 
|  | linalg.eigvals(A, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the eigenvalues of a square matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **eigenvalues** of a square matrix :math:`A \in \mathbb{K}^{n \times n}` are defined | 
|  | as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{C}} | 
|  |  | 
|  | where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | .. note:: The eigenvalues of a real matrix may be complex, as the roots of a real polynomial may be complex. | 
|  |  | 
|  | The eigenvalues of a matrix are always well-defined, even when the matrix is not diagonalizable. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.eig` computes the full eigenvalue decomposition. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A complex-valued tensor cointaining the eigenvalues even when :attr:`A` is real. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 2, dtype=torch.complex128) | 
|  | >>> L = torch.linalg.eigvals(A) | 
|  | >>> L | 
|  | tensor([ 1.1226+0.5738j, -0.7537-0.1286j], dtype=torch.complex128) | 
|  |  | 
|  | >>> torch.dist(L, torch.linalg.eig(A).eigenvalues) | 
|  | tensor(2.4576e-07) | 
|  | """) | 
|  |  | 
|  | eigh = _add_docstr(_linalg.linalg_eigh, r""" | 
|  | linalg.eigh(A, UPLO='L', *, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes the eigenvalue decomposition of a complex Hermitian or real symmetric matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **eigenvalue decomposition** of a complex Hermitian or real symmetric matrix | 
|  | :math:`A \in \mathbb{K}^{n \times n}` is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = Q \operatorname{diag}(\Lambda) Q^{\text{H}}\mathrlap{\qquad Q \in \mathbb{K}^{n \times n}, \Lambda \in \mathbb{R}^n} | 
|  |  | 
|  | where :math:`Q^{\text{H}}` is the conjugate transpose when :math:`Q` is complex, and the transpose when :math:`Q` is real-valued. | 
|  | :math:`Q` is orthogonal in the real case and unitary in the complex case. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | :attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead: | 
|  |  | 
|  | - If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation. | 
|  | - If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used. | 
|  |  | 
|  | The eigenvalues are returned in ascending order. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. note:: The eigenvalues of real symmetric or complex Hermitian matrices are always real. | 
|  |  | 
|  | .. warning:: The eigenvectors of a symmetric matrix are not unique, nor are they continuous with | 
|  | respect to :attr:`A`. Due to this lack of uniqueness, different hardware and | 
|  | software may compute different eigenvectors. | 
|  |  | 
|  | This non-uniqueness is caused by the fact that multiplying an eigenvector by | 
|  | `-1` in the real case or by :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex | 
|  | case produces another set of valid eigenvectors of the matrix. | 
|  | For this reason, the loss function shall not depend on the phase of the eigenvectors, as | 
|  | this quantity is not well-defined. | 
|  | This is checked for complex inputs when computing the gradients of this function. As such, | 
|  | when inputs are complex and are on a CUDA device, the computation of the gradients | 
|  | of this function synchronizes that device with the CPU. | 
|  |  | 
|  | .. warning:: Gradients computed using the `eigenvectors` tensor will only be finite when | 
|  | :attr:`A` has distinct eigenvalues. | 
|  | Furthermore, if the distance between any two eigenvalues is close to zero, | 
|  | the gradient will be numerically unstable, as it depends on the eigenvalues | 
|  | :math:`\lambda_i` through the computation of | 
|  | :math:`\frac{1}{\min_{i \neq j} \lambda_i - \lambda_j}`. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.eigvalsh` computes only the eigenvalues of a Hermitian matrix. | 
|  | Unlike :func:`torch.linalg.eigh`, the gradients of :func:`~eigvalsh` are always | 
|  | numerically stable. | 
|  |  | 
|  | :func:`torch.linalg.cholesky` for a different decomposition of a Hermitian matrix. | 
|  | The Cholesky decomposition gives less information about the matrix but is much faster | 
|  | to compute than the eigenvalue decomposition. | 
|  |  | 
|  | :func:`torch.linalg.eig` for a (slower) function that computes the eigenvalue decomposition | 
|  | of a not necessarily Hermitian square matrix. | 
|  |  | 
|  | :func:`torch.linalg.svd` for a (slower) function that computes the more general SVD | 
|  | decomposition of matrices of any shape. | 
|  |  | 
|  | :func:`torch.linalg.qr` for another (much faster) decomposition that works on general | 
|  | matrices. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions | 
|  | consisting of symmetric or Hermitian matrices. | 
|  | UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part | 
|  | of :attr:`A` in the computations. Default: `'L'`. | 
|  |  | 
|  | Keyword args: | 
|  | out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(eigenvalues, eigenvectors)` which corresponds to :math:`\Lambda` and :math:`Q` above. | 
|  |  | 
|  | `eigenvalues` will always be real-valued, even when :attr:`A` is complex. | 
|  | It will also be ordered in ascending order. | 
|  |  | 
|  | `eigenvectors` will have the same dtype as :attr:`A` and will contain the eigenvectors as its columns. | 
|  |  | 
|  | Examples:: | 
|  | >>> A = torch.randn(2, 2, dtype=torch.complex128) | 
|  | >>> A = A + A.T.conj()  # creates a Hermitian matrix | 
|  | >>> A | 
|  | tensor([[2.9228+0.0000j, 0.2029-0.0862j], | 
|  | [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) | 
|  | >>> L, Q = torch.linalg.eigh(A) | 
|  | >>> L | 
|  | tensor([0.3277, 2.9415], dtype=torch.float64) | 
|  | >>> Q | 
|  | tensor([[-0.0846+-0.0000j, -0.9964+0.0000j], | 
|  | [ 0.9170+0.3898j, -0.0779-0.0331j]], dtype=torch.complex128) | 
|  | >>> torch.dist(Q @ torch.diag(L.cdouble()) @ Q.T.conj(), A) | 
|  | tensor(6.1062e-16, dtype=torch.float64) | 
|  |  | 
|  | >>> A = torch.randn(3, 2, 2, dtype=torch.float64) | 
|  | >>> A = A + A.mT  # creates a batch of symmetric matrices | 
|  | >>> L, Q = torch.linalg.eigh(A) | 
|  | >>> torch.dist(Q @ torch.diag_embed(L) @ Q.mH, A) | 
|  | tensor(1.5423e-15, dtype=torch.float64) | 
|  | """) | 
|  |  | 
|  | eigvalsh = _add_docstr(_linalg.linalg_eigvalsh, r""" | 
|  | linalg.eigvalsh(A, UPLO='L', *, out=None) -> Tensor | 
|  |  | 
|  | Computes the eigenvalues of a complex Hermitian or real symmetric matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **eigenvalues** of a complex Hermitian or real symmetric  matrix :math:`A \in \mathbb{K}^{n \times n}` | 
|  | are defined as the roots (counted with multiplicity) of the polynomial `p` of degree `n` given by | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | p(\lambda) = \operatorname{det}(A - \lambda \mathrm{I}_n)\mathrlap{\qquad \lambda \in \mathbb{R}} | 
|  |  | 
|  | where :math:`\mathrm{I}_n` is the `n`-dimensional identity matrix. | 
|  | The eigenvalues of a real symmetric or complex Hermitian matrix are always real. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | The eigenvalues are returned in ascending order. | 
|  |  | 
|  | :attr:`A` is assumed to be Hermitian (resp. symmetric), but this is not checked internally, instead: | 
|  |  | 
|  | - If :attr:`UPLO`\ `= 'L'` (default), only the lower triangular part of the matrix is used in the computation. | 
|  | - If :attr:`UPLO`\ `= 'U'`, only the upper triangular part of the matrix is used. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.eigh` computes the full eigenvalue decomposition. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions | 
|  | consisting of symmetric or Hermitian matrices. | 
|  | UPLO ('L', 'U', optional): controls whether to use the upper or lower triangular part | 
|  | of :attr:`A` in the computations. Default: `'L'`. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A real-valued tensor cointaining the eigenvalues even when :attr:`A` is complex. | 
|  | The eigenvalues are returned in ascending order. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 2, dtype=torch.complex128) | 
|  | >>> A = A + A.T.conj()  # creates a Hermitian matrix | 
|  | >>> A | 
|  | tensor([[2.9228+0.0000j, 0.2029-0.0862j], | 
|  | [0.2029+0.0862j, 0.3464+0.0000j]], dtype=torch.complex128) | 
|  | >>> torch.linalg.eigvalsh(A) | 
|  | tensor([0.3277, 2.9415], dtype=torch.float64) | 
|  |  | 
|  | >>> A = torch.randn(3, 2, 2, dtype=torch.float64) | 
|  | >>> A = A + A.mT  # creates a batch of symmetric matrices | 
|  | >>> torch.linalg.eigvalsh(A) | 
|  | tensor([[ 2.5797,  3.4629], | 
|  | [-4.1605,  1.3780], | 
|  | [-3.1113,  2.7381]], dtype=torch.float64) | 
|  | """) | 
|  |  | 
|  | householder_product = _add_docstr(_linalg.linalg_householder_product, r""" | 
|  | householder_product(A, tau, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the first `n` columns of a product of Householder matrices. | 
|  |  | 
|  | Let :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, and | 
|  | let :math:`V \in \mathbb{K}^{m \times n}` be a matrix with columns :math:`v_i \in \mathbb{K}^m` | 
|  | for :math:`i=1,\ldots,m` with :math:`m \geq n`. Denote by :math:`w_i` the vector resulting from | 
|  | zeroing out the first :math:`i-1` compontents of :math:`v_i` and setting to `1` the :math:`i`-th. | 
|  | For a vector :math:`\tau \in \mathbb{K}^k` with :math:`k \leq n`, this function computes the | 
|  | first :math:`n` columns of the matrix | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | H_1H_2 ... H_k \qquad\text{with}\qquad H_i = \mathrm{I}_m - \tau_i w_i w_i^{\text{H}} | 
|  |  | 
|  | where :math:`\mathrm{I}_m` is the `m`-dimensional identity matrix and :math:`w^{\text{H}}` is the | 
|  | conjugate transpose when :math:`w` is complex, and the transpose when :math:`w` is real-valued. | 
|  | The output matrix is the same size as the input matrix :attr:`A`. | 
|  |  | 
|  | See `Representation of Orthogonal or Unitary Matrices`_ for further details. | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if the inputs are batches of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.geqrf` can be used together with this function to form the `Q` from the | 
|  | :func:`~qr` decomposition. | 
|  |  | 
|  | :func:`torch.ormqr` is a related function that computes the matrix multiplication | 
|  | of a product of Householder matrices with another matrix. | 
|  | However, that function is not supported by autograd. | 
|  |  | 
|  | .. warning:: | 
|  | Gradient computations are only well-defined if :math:`tau_i \neq \frac{1}{||v_i||^2}`. | 
|  | If this condition is not met, no error will be thrown, but the gradient produced may contain `NaN`. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  | tau (Tensor): tensor of shape `(*, k)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if :attr:`A` doesn't satisfy the requirement `m >= n`, | 
|  | or :attr:`tau` doesn't satisfy the requirement `n >= k`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 2) | 
|  | >>> h, tau = torch.geqrf(A) | 
|  | >>> Q = torch.linalg.householder_product(h, tau) | 
|  | >>> torch.dist(Q, torch.linalg.qr(A).Q) | 
|  | tensor(0.) | 
|  |  | 
|  | >>> h = torch.randn(3, 2, 2, dtype=torch.complex128) | 
|  | >>> tau = torch.randn(3, 1, dtype=torch.complex128) | 
|  | >>> Q = torch.linalg.householder_product(h, tau) | 
|  | >>> Q | 
|  | tensor([[[ 1.8034+0.4184j,  0.2588-1.0174j], | 
|  | [-0.6853+0.7953j,  2.0790+0.5620j]], | 
|  |  | 
|  | [[ 1.4581+1.6989j, -1.5360+0.1193j], | 
|  | [ 1.3877-0.6691j,  1.3512+1.3024j]], | 
|  |  | 
|  | [[ 1.4766+0.5783j,  0.0361+0.6587j], | 
|  | [ 0.6396+0.1612j,  1.3693+0.4481j]]], dtype=torch.complex128) | 
|  |  | 
|  | .. _Representation of Orthogonal or Unitary Matrices: | 
|  | https://www.netlib.org/lapack/lug/node128.html | 
|  | """) | 
|  |  | 
|  | ldl_factor = _add_docstr(_linalg.linalg_ldl_factor, r""" | 
|  | linalg.ldl_factor(A, *, hermitian=False, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes a compact representation of the LDL factorization of a Hermitian or symmetric (possibly indefinite) matrix. | 
|  |  | 
|  | When :attr:`A` is complex valued it can be Hermitian (:attr:`hermitian`\ `= True`) | 
|  | or symmetric (:attr:`hermitian`\ `= False`). | 
|  |  | 
|  | The factorization is of the form the form :math:`A = L D L^T`. | 
|  | If :attr:`hermitian` is `True` then transpose operation is the conjugate transpose. | 
|  |  | 
|  | :math:`L` (or :math:`U`) and :math:`D` are stored in compact form in ``LD``. | 
|  | They follow the format specified by `LAPACK's sytrf`_ function. | 
|  | These tensors may be used in :func:`torch.linalg.ldl_solve` to solve linear systems. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.ldl_factor_ex")} | 
|  | """ + r""" | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape (*, n, n) where * is zero or more batch dimensions consisting of symmetric or Hermitian matrices. | 
|  | `(*, n, n)` where `*` is one or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric. | 
|  | For real-valued matrices, this switch has no effect. Default: `False`. | 
|  | out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(LD, pivots)`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> A = A @ A.mT # make symmetric | 
|  | >>> A | 
|  | tensor([[7.2079, 4.2414, 1.9428], | 
|  | [4.2414, 3.4554, 0.3264], | 
|  | [1.9428, 0.3264, 1.3823]]) | 
|  | >>> LD, pivots = torch.linalg.ldl_factor(A) | 
|  | >>> LD | 
|  | tensor([[ 7.2079,  0.0000,  0.0000], | 
|  | [ 0.5884,  0.9595,  0.0000], | 
|  | [ 0.2695, -0.8513,  0.1633]]) | 
|  | >>> pivots | 
|  | tensor([1, 2, 3], dtype=torch.int32) | 
|  |  | 
|  | .. _LAPACK's sytrf: | 
|  | https://www.netlib.org/lapack/explore-html/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html | 
|  | """) | 
|  |  | 
|  | ldl_factor_ex = _add_docstr(_linalg.linalg_ldl_factor_ex, r""" | 
|  | linalg.ldl_factor_ex(A, *, hermitian=False, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) | 
|  |  | 
|  | This is a version of :func:`~ldl_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. | 
|  | It also returns the :attr:`info` tensor returned by `LAPACK's sytrf`_. | 
|  | ``info`` stores integer error codes from the backend library. | 
|  | A positive integer indicates the diagonal element of :math:`D` that is zero. | 
|  | Division by 0 will occur if the result is used for solving a system of linear equations. | 
|  | ``info`` filled with zeros indicates that the factorization was successful. | 
|  | If ``check_errors=True`` and ``info`` contains positive integers, then a `RuntimeError` is thrown. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note_ex"]} | 
|  |  | 
|  | .. warning:: {common_notes["experimental_warning"]} | 
|  | """ + r""" | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape (*, n, n) where * is zero or more batch dimensions consisting of symmetric or Hermitian matrices. | 
|  | `(*, n, n)` where `*` is one or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric. | 
|  | For real-valued matrices, this switch has no effect. Default: `False`. | 
|  | check_errors (bool, optional): controls whether to check the content of ``info`` and raise | 
|  | an error if it is non-zero. Default: `False`. | 
|  | out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(LD, pivots, info)`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> A = A @ A.mT # make symmetric | 
|  | >>> A | 
|  | tensor([[7.2079, 4.2414, 1.9428], | 
|  | [4.2414, 3.4554, 0.3264], | 
|  | [1.9428, 0.3264, 1.3823]]) | 
|  | >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A) | 
|  | >>> LD | 
|  | tensor([[ 7.2079,  0.0000,  0.0000], | 
|  | [ 0.5884,  0.9595,  0.0000], | 
|  | [ 0.2695, -0.8513,  0.1633]]) | 
|  | >>> pivots | 
|  | tensor([1, 2, 3], dtype=torch.int32) | 
|  | >>> info | 
|  | tensor(0, dtype=torch.int32) | 
|  |  | 
|  | .. _LAPACK's sytrf: | 
|  | https://www.netlib.org/lapack/explore-html/d3/db6/group__double_s_ycomputational_gad91bde1212277b3e909eb6af7f64858a.html | 
|  | """) | 
|  |  | 
|  | ldl_solve = _add_docstr(_linalg.linalg_ldl_solve, r""" | 
|  | linalg.ldl_solve(LD, pivots, B, *, hermitian=False, out=None) -> Tensor | 
|  |  | 
|  | Computes the solution of a system of linear equations using the LDL factorization. | 
|  |  | 
|  | :attr:`LD` and :attr:`pivots` are the compact representation of the LDL factorization and | 
|  | are expected to be computed by :func:`torch.linalg.ldl_factor_ex`. | 
|  | :attr:`hermitian` argument to this function should be the same | 
|  | as the corresponding argumens in :func:`torch.linalg.ldl_factor_ex`. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | """ + fr""" | 
|  | .. warning:: {common_notes["experimental_warning"]} | 
|  | """ + r""" | 
|  |  | 
|  | Args: | 
|  | LD (Tensor): the `n \times n` matrix or the batch of such matrices of size | 
|  | `(*, n, n)` where `*` is one or more batch dimensions. | 
|  | pivots (Tensor): the pivots corresponding to the LDL factorization of :attr:`LD`. | 
|  | B (Tensor): right-hand side tensor of shape `(*, n, k)`. | 
|  |  | 
|  | Keyword args: | 
|  | hermitian (bool, optional): whether to consider the decomposed matrix to be Hermitian or symmetric. | 
|  | For real-valued matrices, this switch has no effect. Default: `False`. | 
|  | out (tuple, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`. | 
|  | Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 3, 3) | 
|  | >>> A = A @ A.mT # make symmetric | 
|  | >>> LD, pivots, info = torch.linalg.ldl_factor_ex(A) | 
|  | >>> B = torch.randn(2, 3, 4) | 
|  | >>> X = torch.linalg.ldl_solve(LD, pivots, B) | 
|  | >>> torch.linalg.norm(A @ X - B) | 
|  | >>> tensor(0.0001) | 
|  | """) | 
|  |  | 
|  | lstsq = _add_docstr(_linalg.linalg_lstsq, r""" | 
|  | torch.linalg.lstsq(A, B, rcond=None, *, driver=None) -> (Tensor, Tensor, Tensor, Tensor) | 
|  |  | 
|  | Computes a solution to the least squares problem of a system of linear equations. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **least squares problem** for a linear system :math:`AX = B` with | 
|  | :math:`A \in \mathbb{K}^{m \times n}, B \in \mathbb{K}^{m \times k}` is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \min_{X \in \mathbb{K}^{n \times k}} \|AX - B\|_F | 
|  |  | 
|  | where :math:`\|-\|_F` denotes the Frobenius norm. | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if the inputs are batches of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | :attr:`driver` chooses the LAPACK/MAGMA function that will be used. | 
|  | For CPU inputs the valid values are `'gels'`, `'gelsy'`, `'gelsd`, `'gelss'`. | 
|  | For CUDA input, the only valid driver is `'gels'`, which assumes that :attr:`A` is full-rank. | 
|  | To choose the best driver on CPU consider: | 
|  |  | 
|  | - If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss. | 
|  |  | 
|  | - For a general matrix: `'gelsy'` (QR with pivoting) (default) | 
|  | - If :attr:`A` is full-rank: `'gels'` (QR) | 
|  |  | 
|  | - If :attr:`A` is not well-conditioned. | 
|  |  | 
|  | - `'gelsd'` (tridiagonal reduction and SVD) | 
|  | - But if you run into memory issues: `'gelss'` (full SVD). | 
|  |  | 
|  | See also the `full description of these drivers`_ | 
|  |  | 
|  | :attr:`rcond` is used to determine the effective rank of the matrices in :attr:`A` | 
|  | when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`). | 
|  | In this case, if :math:`\sigma_i` are the singular values of `A` in decreasing order, | 
|  | :math:`\sigma_i` will be rounded down to zero if :math:`\sigma_i \leq \text{rcond} \cdot \sigma_1`. | 
|  | If :attr:`rcond`\ `= None` (default), :attr:`rcond` is set to the machine precision of the dtype of :attr:`A` times `max(m, n)`. | 
|  |  | 
|  | This function returns the solution to the problem and some extra information in a named tuple of | 
|  | four tensors `(solution, residuals, rank, singular_values)`. For inputs :attr:`A`, :attr:`B` | 
|  | of shape `(*, m, n)`, `(*, m, k)` respectively, it cointains | 
|  |  | 
|  | - `solution`: the least squares solution. It has shape `(*, n, k)`. | 
|  | - `residuals`: the squared residuals of the solutions, that is, :math:`\|AX - B\|_F^2`. | 
|  | It has shape equal to the batch dimensions of :attr:`A`. | 
|  | It is computed when `m > n` and every matrix in :attr:`A` is full-rank, | 
|  | otherwise, it is an empty tensor. | 
|  | If :attr:`A` is a batch of matrices and any matrix in the batch is not full rank, | 
|  | then an empty tensor is returned. This behavior may change in a future PyTorch release. | 
|  | - `rank`: tensor of ranks of the matrices in :attr:`A`. | 
|  | It has shape equal to the batch dimensions of :attr:`A`. | 
|  | It is computed when :attr:`driver` is one of (`'gelsy'`, `'gelsd'`, `'gelss'`), | 
|  | otherwise it is an empty tensor. | 
|  | - `singular_values`: tensor of singular values of the matrices in :attr:`A`. | 
|  | It has shape `(*, min(m, n))`. | 
|  | It is computed when :attr:`driver` is one of (`'gelsd'`, `'gelss'`), | 
|  | otherwise it is an empty tensor. | 
|  |  | 
|  | .. note:: | 
|  | This function computes `X = \ `:attr:`A`\ `.pinverse() @ \ `:attr:`B` in a faster and | 
|  | more numerically stable way than performing the computations separately. | 
|  |  | 
|  | .. warning:: | 
|  | The default value of :attr:`rcond` may change in a future PyTorch release. | 
|  | It is therefore recommended to use a fixed value to avoid potential | 
|  | breaking changes. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): lhs tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  | B (Tensor): rhs tensor of shape `(*, m, k)` where `*` is zero or more batch dimensions. | 
|  | rcond (float, optional): used to determine the effective rank of :attr:`A`. | 
|  | If :attr:`rcond`\ `= None`, :attr:`rcond` is set to the machine | 
|  | precision of the dtype of :attr:`A` times `max(m, n)`. Default: `None`. | 
|  |  | 
|  | Keyword args: | 
|  | driver (str, optional): name of the LAPACK/MAGMA method to be used. | 
|  | If `None`, `'gelsy'` is used for CPU inputs and `'gels'` for CUDA inputs. | 
|  | Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(solution, residuals, rank, singular_values)`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.tensor([[[10, 2, 3], [3, 10, 5], [5, 6, 12]]], dtype=torch.float) # shape (1, 3, 3) | 
|  | >>> B = torch.tensor([[[2, 5, 1], [3, 2, 1], [5, 1, 9]], | 
|  | [[4, 2, 9], [2, 0, 3], [2, 5, 3]]], dtype=torch.float) # shape (2, 3, 3) | 
|  | >>> X = torch.linalg.lstsq(A, B).solution # A is broadcasted to shape (2, 3, 3) | 
|  | >>> torch.dist(X, torch.linalg.pinv(A) @ B) | 
|  | tensor(2.0862e-07) | 
|  |  | 
|  | >>> S = torch.linalg.lstsq(A, B, driver='gelsd').singular_values | 
|  | >>> torch.dist(S, torch.linalg.svdvals(A)) | 
|  | tensor(5.7220e-06) | 
|  |  | 
|  | >>> A[:, 0].zero_()  # Decrease the rank of A | 
|  | >>> rank = torch.linalg.lstsq(A, B).rank | 
|  | >>> rank | 
|  | tensor([2]) | 
|  |  | 
|  | .. _condition number: | 
|  | https://pytorch.org/docs/master/linalg.html#torch.linalg.cond | 
|  | .. _full description of these drivers: | 
|  | https://www.netlib.org/lapack/lug/node27.html | 
|  | """) | 
|  |  | 
|  | matrix_power = _add_docstr(_linalg.linalg_matrix_power, r""" | 
|  | matrix_power(A, n, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the `n`-th power of a square matrix for an integer `n`. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | If :attr:`n`\ `= 0`, it returns the identity matrix (or batch) of the same shape | 
|  | as :attr:`A`. If :attr:`n` is negative, it returns the inverse of each matrix | 
|  | (if invertible) raised to the power of `abs(n)`. | 
|  |  | 
|  | .. note:: | 
|  | Consider using :func:`torch.linalg.solve` if possible for multiplying a matrix on the left by | 
|  | a negative power as, if :attr:`n`\ `> 0`:: | 
|  |  | 
|  | matrix_power(torch.linalg.solve(A, B), n) == matrix_power(A, -n)  @ B | 
|  |  | 
|  | It is always prefered to use :func:`~solve` when possible, as it is faster and more | 
|  | numerically stable than computing :math:`A^{-n}` explicitly. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.solve` computes :attr:`A`\ `.inverse() @ \ `:attr:`B` with a | 
|  | numerically stable algorithm. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, m)` where `*` is zero or more batch dimensions. | 
|  | n (int): the exponent. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if :attr:`n`\ `< 0` and the matrix :attr:`A` or any matrix in the | 
|  | batch of matrices :attr:`A` is not invertible. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> torch.linalg.matrix_power(A, 0) | 
|  | tensor([[1., 0., 0.], | 
|  | [0., 1., 0.], | 
|  | [0., 0., 1.]]) | 
|  | >>> torch.linalg.matrix_power(A, 3) | 
|  | tensor([[ 1.0756,  0.4980,  0.0100], | 
|  | [-1.6617,  1.4994, -1.9980], | 
|  | [-0.4509,  0.2731,  0.8001]]) | 
|  | >>> torch.linalg.matrix_power(A.expand(2, -1, -1), -2) | 
|  | tensor([[[ 0.2640,  0.4571, -0.5511], | 
|  | [-1.0163,  0.3491, -1.5292], | 
|  | [-0.4899,  0.0822,  0.2773]], | 
|  | [[ 0.2640,  0.4571, -0.5511], | 
|  | [-1.0163,  0.3491, -1.5292], | 
|  | [-0.4899,  0.0822,  0.2773]]]) | 
|  | """) | 
|  |  | 
|  | matrix_rank = _add_docstr(_linalg.linalg_matrix_rank, r""" | 
|  | linalg.matrix_rank(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor | 
|  |  | 
|  | Computes the numerical rank of a matrix. | 
|  |  | 
|  | The matrix rank is computed as the number of singular values | 
|  | (or eigenvalues in absolute value when :attr:`hermitian`\ `= True`) | 
|  | that are greater than :math:`\max(\text{atol}, \sigma_1 * \text{rtol})` threshold, | 
|  | where :math:`\sigma_1` is the largest singular value (or eigenvalue). | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or | 
|  | symmetric if real, but this is not checked internally. Instead, just the lower | 
|  | triangular part of the matrix is used in the computations. | 
|  |  | 
|  | If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, | 
|  | the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` | 
|  | and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). | 
|  | If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then | 
|  | :attr:`rtol` is set to zero. | 
|  |  | 
|  | If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that | 
|  | of the singular values of :attr:`A` as returned by :func:`torch.linalg.svdvals`. | 
|  |  | 
|  | .. note:: | 
|  | This function has NumPy compatible variant `linalg.matrix_rank(A, tol, hermitian=False)`. | 
|  | However, use of the positional argument :attr:`tol` is deprecated in favor of :attr:`atol` and :attr:`rtol`. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: The matrix rank is computed using a singular value decomposition | 
|  | :func:`torch.linalg.svdvals` if :attr:`hermitian`\ `= False` (default) and the eigenvalue | 
|  | decomposition :func:`torch.linalg.eigvalsh` when :attr:`hermitian`\ `= True`. | 
|  | {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  | tol (float, Tensor, optional): [NumPy Compat] Alias for :attr:`atol`. Default: `None`. | 
|  |  | 
|  | Keyword args: | 
|  | atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero. | 
|  | Default: `None`. | 
|  | rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`. | 
|  | Default: `None`. | 
|  | hermitian(bool): indicates whether :attr:`A` is Hermitian if complex | 
|  | or symmetric if real. Default: `False`. | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.eye(10) | 
|  | >>> torch.linalg.matrix_rank(A) | 
|  | tensor(10) | 
|  | >>> B = torch.eye(10) | 
|  | >>> B[0, 0] = 0 | 
|  | >>> torch.linalg.matrix_rank(B) | 
|  | tensor(9) | 
|  |  | 
|  | >>> A = torch.randn(4, 3, 2) | 
|  | >>> torch.linalg.matrix_rank(A) | 
|  | tensor([2, 2, 2, 2]) | 
|  |  | 
|  | >>> A = torch.randn(2, 4, 2, 3) | 
|  | >>> torch.linalg.matrix_rank(A) | 
|  | tensor([[2, 2, 2, 2], | 
|  | [2, 2, 2, 2]]) | 
|  |  | 
|  | >>> A = torch.randn(2, 4, 3, 3, dtype=torch.complex64) | 
|  | >>> torch.linalg.matrix_rank(A) | 
|  | tensor([[3, 3, 3, 3], | 
|  | [3, 3, 3, 3]]) | 
|  | >>> torch.linalg.matrix_rank(A, hermitian=True) | 
|  | tensor([[3, 3, 3, 3], | 
|  | [3, 3, 3, 3]]) | 
|  | >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0) | 
|  | tensor([[3, 2, 2, 2], | 
|  | [1, 2, 1, 2]]) | 
|  | >>> torch.linalg.matrix_rank(A, atol=1.0, rtol=0.0, hermitian=True) | 
|  | tensor([[2, 2, 2, 1], | 
|  | [1, 2, 2, 2]]) | 
|  | """) | 
|  |  | 
|  | norm = _add_docstr(_linalg.linalg_norm, r""" | 
|  | linalg.norm(A, ord=None, dim=None, keepdim=False, *, out=None, dtype=None) -> Tensor | 
|  |  | 
|  | Computes a vector or matrix norm. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  |  | 
|  | Whether this function computes a vector or matrix norm is determined as follows: | 
|  |  | 
|  | - If :attr:`dim` is an `int`, the vector norm will be computed. | 
|  | - If :attr:`dim` is a `2`-`tuple`, the matrix norm will be computed. | 
|  | - If :attr:`dim`\ `= None` and :attr:`ord`\ `= None`, | 
|  | :attr:`A` will be flattened to 1D and the `2`-norm of the resulting vector will be computed. | 
|  | - If :attr:`dim`\ `= None` and :attr:`ord` `!= None`, :attr:`A` must be 1D or 2D. | 
|  |  | 
|  | :attr:`ord` defines the norm that is computed. The following norms are supported: | 
|  |  | 
|  | ======================     =========================  ======================================================== | 
|  | :attr:`ord`                norm for matrices          norm for vectors | 
|  | ======================     =========================  ======================================================== | 
|  | `None` (default)           Frobenius norm             `2`-norm (see below) | 
|  | `'fro'`                    Frobenius norm             -- not supported -- | 
|  | `'nuc'`                    nuclear norm               -- not supported -- | 
|  | `inf`                      `max(sum(abs(x), dim=1))`  `max(abs(x))` | 
|  | `-inf`                     `min(sum(abs(x), dim=1))`  `min(abs(x))` | 
|  | `0`                        -- not supported --        `sum(x != 0)` | 
|  | `1`                        `max(sum(abs(x), dim=0))`  as below | 
|  | `-1`                       `min(sum(abs(x), dim=0))`  as below | 
|  | `2`                        largest singular value     as below | 
|  | `-2`                       smallest singular value    as below | 
|  | other `int` or `float`     -- not supported --        `sum(abs(x)^{ord})^{(1 / ord)}` | 
|  | ======================     =========================  ======================================================== | 
|  |  | 
|  | where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.vector_norm` computes a vector norm. | 
|  |  | 
|  | :func:`torch.linalg.matrix_norm` computes a matrix norm. | 
|  |  | 
|  | The above functions are often clearer and more flexible than using :func:`torch.linalg.norm`. | 
|  | For example, `torch.linalg.norm(A, ord=1, dim=(0, 1))` always | 
|  | computes a matrix norm, but with `torch.linalg.vector_norm(A, ord=1, dim=(0, 1))` it is possible | 
|  | to compute a vector norm over the two dimensions. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n)` or `(*, m, n)` where `*` is zero or more batch dimensions | 
|  | ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `None` | 
|  | dim (int, Tuple[int], optional): dimensions over which to compute | 
|  | the vector or matrix norm. See above for the behavior when :attr:`dim`\ `= None`. | 
|  | Default: `None` | 
|  | keepdim (bool, optional): If set to `True`, the reduced dimensions are retained | 
|  | in the result as dimensions with size one. Default: `False` | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  | dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to | 
|  | :attr:`dtype` before performing the operation, and the returned tensor's type | 
|  | will be :attr:`dtype`. Default: `None` | 
|  |  | 
|  | Returns: | 
|  | A real-valued tensor, even when :attr:`A` is complex. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> from torch import linalg as LA | 
|  | >>> a = torch.arange(9, dtype=torch.float) - 4 | 
|  | >>> a | 
|  | tensor([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.]) | 
|  | >>> B = a.reshape((3, 3)) | 
|  | >>> B | 
|  | tensor([[-4., -3., -2.], | 
|  | [-1.,  0.,  1.], | 
|  | [ 2.,  3.,  4.]]) | 
|  |  | 
|  | >>> LA.norm(a) | 
|  | tensor(7.7460) | 
|  | >>> LA.norm(B) | 
|  | tensor(7.7460) | 
|  | >>> LA.norm(B, 'fro') | 
|  | tensor(7.7460) | 
|  | >>> LA.norm(a, float('inf')) | 
|  | tensor(4.) | 
|  | >>> LA.norm(B, float('inf')) | 
|  | tensor(9.) | 
|  | >>> LA.norm(a, -float('inf')) | 
|  | tensor(0.) | 
|  | >>> LA.norm(B, -float('inf')) | 
|  | tensor(2.) | 
|  |  | 
|  | >>> LA.norm(a, 1) | 
|  | tensor(20.) | 
|  | >>> LA.norm(B, 1) | 
|  | tensor(7.) | 
|  | >>> LA.norm(a, -1) | 
|  | tensor(0.) | 
|  | >>> LA.norm(B, -1) | 
|  | tensor(6.) | 
|  | >>> LA.norm(a, 2) | 
|  | tensor(7.7460) | 
|  | >>> LA.norm(B, 2) | 
|  | tensor(7.3485) | 
|  |  | 
|  | >>> LA.norm(a, -2) | 
|  | tensor(0.) | 
|  | >>> LA.norm(B.double(), -2) | 
|  | tensor(1.8570e-16, dtype=torch.float64) | 
|  | >>> LA.norm(a, 3) | 
|  | tensor(5.8480) | 
|  | >>> LA.norm(a, -3) | 
|  | tensor(0.) | 
|  |  | 
|  | Using the :attr:`dim` argument to compute vector norms:: | 
|  |  | 
|  | >>> c = torch.tensor([[1., 2., 3.], | 
|  | ...                   [-1, 1, 4]]) | 
|  | >>> LA.norm(c, dim=0) | 
|  | tensor([1.4142, 2.2361, 5.0000]) | 
|  | >>> LA.norm(c, dim=1) | 
|  | tensor([3.7417, 4.2426]) | 
|  | >>> LA.norm(c, ord=1, dim=1) | 
|  | tensor([6., 6.]) | 
|  |  | 
|  | Using the :attr:`dim` argument to compute matrix norms:: | 
|  |  | 
|  | >>> A = torch.arange(8, dtype=torch.float).reshape(2, 2, 2) | 
|  | >>> LA.norm(A, dim=(1,2)) | 
|  | tensor([ 3.7417, 11.2250]) | 
|  | >>> LA.norm(A[0, :, :]), LA.norm(A[1, :, :]) | 
|  | (tensor(3.7417), tensor(11.2250)) | 
|  | """) | 
|  |  | 
|  | vector_norm = _add_docstr(_linalg.linalg_vector_norm, r""" | 
|  | linalg.vector_norm(x, ord=2, dim=None, keepdim=False, *, dtype=None, out=None) -> Tensor | 
|  |  | 
|  | Computes a vector norm. | 
|  |  | 
|  | If :attr:`x` is complex valued, it computes the norm of :attr:`x`\ `.abs()` | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  |  | 
|  | This function does not necessarily treat multidimensonal :attr:`x` as a batch of | 
|  | vectors, instead: | 
|  |  | 
|  | - If :attr:`dim`\ `= None`, :attr:`x` will be flattened before the norm is computed. | 
|  | - If :attr:`dim` is an `int` or a `tuple`, the norm will be computed over these dimensions | 
|  | and the other dimensions will be treated as batch dimensions. | 
|  |  | 
|  | This behavior is for consistency with :func:`torch.linalg.norm`. | 
|  |  | 
|  | :attr:`ord` defines the vector norm that is computed. The following norms are supported: | 
|  |  | 
|  | ======================   =============================== | 
|  | :attr:`ord`              vector norm | 
|  | ======================   =============================== | 
|  | `2` (default)            `2`-norm (see below) | 
|  | `inf`                    `max(abs(x))` | 
|  | `-inf`                   `min(abs(x))` | 
|  | `0`                      `sum(x != 0)` | 
|  | other `int` or `float`   `sum(abs(x)^{ord})^{(1 / ord)}` | 
|  | ======================   =============================== | 
|  |  | 
|  | where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. | 
|  |  | 
|  | :attr:`dtype` may be used to perform the computation in a more precise dtype. | 
|  | It is semantically equivalent to calling ``linalg.vector_norm(x.to(dtype))`` | 
|  | but it is faster in some cases. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.matrix_norm` computes a matrix norm. | 
|  |  | 
|  | Args: | 
|  | x (Tensor): tensor, flattened by default, but this behavior can be | 
|  | controlled using :attr:`dim`. | 
|  | ord (int, float, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `2` | 
|  | dim (int, Tuple[int], optional): dimensions over which to compute | 
|  | the norm. See above for the behavior when :attr:`dim`\ `= None`. | 
|  | Default: `None` | 
|  | keepdim (bool, optional): If set to `True`, the reduced dimensions are retained | 
|  | in the result as dimensions with size one. Default: `False` | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  | dtype (:class:`torch.dtype`, optional): type used to perform the accumulation and the return. | 
|  | If specified, :attr:`x` is cast to :attr:`dtype` before performing the operation, | 
|  | and the returned tensor’s type will be :attr:`dtype` if real and of its real counterpart if complex. | 
|  | :attr:`dtype` may be complex if :attr:`x` is complex, otherwise it must be real. | 
|  | :attr:`x` should be convertible without narrowing to :attr:`dtype`. Default: None | 
|  |  | 
|  | Returns: | 
|  | A real-valued tensor, even when :attr:`x` is complex. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> from torch import linalg as LA | 
|  | >>> a = torch.arange(9, dtype=torch.float) - 4 | 
|  | >>> a | 
|  | tensor([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.]) | 
|  | >>> B = a.reshape((3, 3)) | 
|  | >>> B | 
|  | tensor([[-4., -3., -2.], | 
|  | [-1.,  0.,  1.], | 
|  | [ 2.,  3.,  4.]]) | 
|  | >>> LA.vector_norm(a, ord=3.5) | 
|  | tensor(5.4345) | 
|  | >>> LA.vector_norm(B, ord=3.5) | 
|  | tensor(5.4345) | 
|  | """) | 
|  |  | 
|  | matrix_norm = _add_docstr(_linalg.linalg_matrix_norm, r""" | 
|  | linalg.matrix_norm(A, ord='fro', dim=(-2, -1), keepdim=False, *, dtype=None, out=None) -> Tensor | 
|  |  | 
|  | Computes a matrix norm. | 
|  |  | 
|  | If :attr:`A` is complex valued, it computes the norm of :attr:`A`\ `.abs()` | 
|  |  | 
|  | Support input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices: the norm will be computed over the | 
|  | dimensions specified by the 2-tuple :attr:`dim` and the other dimensions will | 
|  | be treated as batch dimensions. The output will have the same batch dimensions. | 
|  |  | 
|  | :attr:`ord` defines the matrix norm that is computed. The following norms are supported: | 
|  |  | 
|  | ======================   ======================================================== | 
|  | :attr:`ord`              matrix norm | 
|  | ======================   ======================================================== | 
|  | `'fro'` (default)        Frobenius norm | 
|  | `'nuc'`                  nuclear norm | 
|  | `inf`                    `max(sum(abs(x), dim=1))` | 
|  | `-inf`                   `min(sum(abs(x), dim=1))` | 
|  | `1`                      `max(sum(abs(x), dim=0))` | 
|  | `-1`                     `min(sum(abs(x), dim=0))` | 
|  | `2`                      largest singular value | 
|  | `-2`                     smallest singular value | 
|  | ======================   ======================================================== | 
|  |  | 
|  | where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor with two or more dimensions. By default its | 
|  | shape is interpreted as `(*, m, n)` where `*` is zero or more | 
|  | batch dimensions, but this behavior can be controlled using :attr:`dim`. | 
|  | ord (int, inf, -inf, 'fro', 'nuc', optional): order of norm. Default: `'fro'` | 
|  | dim (Tuple[int, int], optional): dimensions over which to compute the norm. Default: `(-2, -1)` | 
|  | keepdim (bool, optional): If set to `True`, the reduced dimensions are retained | 
|  | in the result as dimensions with size one. Default: `False` | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  | dtype (:class:`torch.dtype`, optional): If specified, the input tensor is cast to | 
|  | :attr:`dtype` before performing the operation, and the returned tensor's type | 
|  | will be :attr:`dtype`. Default: `None` | 
|  |  | 
|  | Returns: | 
|  | A real-valued tensor, even when :attr:`A` is complex. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> from torch import linalg as LA | 
|  | >>> A = torch.arange(9, dtype=torch.float).reshape(3, 3) | 
|  | >>> A | 
|  | tensor([[0., 1., 2.], | 
|  | [3., 4., 5.], | 
|  | [6., 7., 8.]]) | 
|  | >>> LA.matrix_norm(A) | 
|  | tensor(14.2829) | 
|  | >>> LA.matrix_norm(A, ord=-1) | 
|  | tensor(9.) | 
|  | >>> B = A.expand(2, -1, -1) | 
|  | >>> B | 
|  | tensor([[[0., 1., 2.], | 
|  | [3., 4., 5.], | 
|  | [6., 7., 8.]], | 
|  |  | 
|  | [[0., 1., 2.], | 
|  | [3., 4., 5.], | 
|  | [6., 7., 8.]]]) | 
|  | >>> LA.matrix_norm(B) | 
|  | tensor([14.2829, 14.2829]) | 
|  | >>> LA.matrix_norm(B, dim=(0, 2)) | 
|  | tensor([ 3.1623, 10.0000, 17.2627]) | 
|  | """) | 
|  |  | 
|  | matmul = _add_docstr(_linalg.linalg_matmul, r""" | 
|  | linalg.matmul(input, other, *, out=None) -> Tensor | 
|  |  | 
|  | Alias for :func:`torch.matmul` | 
|  | """) | 
|  |  | 
|  | diagonal = _add_docstr(_linalg.linalg_diagonal, r""" | 
|  | linalg.diagonal(A, *, offset=0, dim1=-2, dim2=-1) -> Tensor | 
|  |  | 
|  | Alias for :func:`torch.diagonal` with defaults :attr:`dim1`\ `= -2`, :attr:`dim2`\ `= -1`. | 
|  | """) | 
|  |  | 
|  | multi_dot = _add_docstr(_linalg.linalg_multi_dot, r""" | 
|  | linalg.multi_dot(tensors, *, out=None) | 
|  |  | 
|  | Efficiently multiplies two or more matrices by reordering the multiplications so that | 
|  | the fewest arithmetic operations are performed. | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  | This function does not support batched inputs. | 
|  |  | 
|  | Every tensor in :attr:`tensors` must be 2D, except for the first and last which | 
|  | may be 1D. If the first tensor is a 1D vector of shape `(n,)` it is treated as a row vector | 
|  | of shape `(1, n)`, similarly if the last tensor is a 1D vector of shape `(n,)` it is treated | 
|  | as a column vector of shape `(n, 1)`. | 
|  |  | 
|  | If the first and last tensors are matrices, the output will be a matrix. | 
|  | However, if either is a 1D vector, then the output will be a 1D vector. | 
|  |  | 
|  | Differences with `numpy.linalg.multi_dot`: | 
|  |  | 
|  | - Unlike `numpy.linalg.multi_dot`, the first and last tensors must either be 1D or 2D | 
|  | whereas NumPy allows them to be nD | 
|  |  | 
|  | .. warning:: This function does not broadcast. | 
|  |  | 
|  | .. note:: This function is implemented by chaining :func:`torch.mm` calls after | 
|  | computing the optimal matrix multiplication order. | 
|  |  | 
|  | .. note:: The cost of multiplying two matrices with shapes `(a, b)` and `(b, c)` is | 
|  | `a * b * c`. Given matrices `A`, `B`, `C` with shapes `(10, 100)`, | 
|  | `(100, 5)`, `(5, 50)` respectively, we can calculate the cost of different | 
|  | multiplication orders as follows: | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \begin{align*} | 
|  | \operatorname{cost}((AB)C) &= 10 \times 100 \times 5 + 10 \times 5 \times 50 = 7500 \\ | 
|  | \operatorname{cost}(A(BC)) &= 10 \times 100 \times 50 + 100 \times 5 \times 50 = 75000 | 
|  | \end{align*} | 
|  |  | 
|  | In this case, multiplying `A` and `B` first followed by `C` is 10 times faster. | 
|  |  | 
|  | Args: | 
|  | tensors (Sequence[Tensor]): two or more tensors to multiply. The first and last | 
|  | tensors may be 1D or 2D. Every other tensor must be 2D. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> from torch.linalg import multi_dot | 
|  |  | 
|  | >>> multi_dot([torch.tensor([1, 2]), torch.tensor([2, 3])]) | 
|  | tensor(8) | 
|  | >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([2, 3])]) | 
|  | tensor([8]) | 
|  | >>> multi_dot([torch.tensor([[1, 2]]), torch.tensor([[2], [3]])]) | 
|  | tensor([[8]]) | 
|  |  | 
|  | >>> A = torch.arange(2 * 3).view(2, 3) | 
|  | >>> B = torch.arange(3 * 2).view(3, 2) | 
|  | >>> C = torch.arange(2 * 2).view(2, 2) | 
|  | >>> multi_dot((A, B, C)) | 
|  | tensor([[ 26,  49], | 
|  | [ 80, 148]]) | 
|  | """) | 
|  |  | 
|  | svd = _add_docstr(_linalg.linalg_svd, r""" | 
|  | linalg.svd(A, full_matrices=True, *, driver=None, out=None) -> (Tensor, Tensor, Tensor) | 
|  |  | 
|  | Computes the singular value decomposition (SVD) of a matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **full SVD** of a matrix | 
|  | :math:`A \in \mathbb{K}^{m \times n}`, if `k = min(m,n)`, is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = U \operatorname{diag}(S) V^{\text{H}} | 
|  | \mathrlap{\qquad U \in \mathbb{K}^{m \times m}, S \in \mathbb{R}^k, V \in \mathbb{K}^{n \times n}} | 
|  |  | 
|  | where :math:`\operatorname{diag}(S) \in \mathbb{K}^{m \times n}`, | 
|  | :math:`V^{\text{H}}` is the conjugate transpose when :math:`V` is complex, and the transpose when :math:`V` is real-valued. | 
|  | The matrices  :math:`U`, :math:`V` (and thus :math:`V^{\text{H}}`) are orthogonal in the real case, and unitary in the complex case. | 
|  |  | 
|  | When `m > n` (resp. `m < n`) we can drop the last `m - n` (resp. `n - m`) columns of `U` (resp. `V`) to form the **reduced SVD**: | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = U \operatorname{diag}(S) V^{\text{H}} | 
|  | \mathrlap{\qquad U \in \mathbb{K}^{m \times k}, S \in \mathbb{R}^k, V \in \mathbb{K}^{k \times n}} | 
|  |  | 
|  | where :math:`\operatorname{diag}(S) \in \mathbb{K}^{k \times k}`. | 
|  | In this case, :math:`U` and :math:`V` also have orthonormal columns. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | The returned decomposition is a named tuple `(U, S, Vh)` | 
|  | which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. | 
|  |  | 
|  | The singular values are returned in descending order. | 
|  |  | 
|  | The parameter :attr:`full_matrices` chooses between the full (default) and reduced SVD. | 
|  |  | 
|  | The :attr:`driver` kwarg may be used in CUDA with a cuSOLVER backend to choose the algorithm used to compute the SVD. | 
|  | The choice of a driver is a trade-off between accuracy and speed. | 
|  |  | 
|  | - If :attr:`A` is well-conditioned (its `condition number`_ is not too large), or you do not mind some precision loss. | 
|  |  | 
|  | - For a general matrix: `'gesvdj'` (Jacobi method) | 
|  | - If :attr:`A` is tall or wide (`m >> n` or `m << n`): `'gesvda'` (Approximate method) | 
|  |  | 
|  | - If :attr:`A` is not well-conditioned or precision is relevant: `'gesvd'` (QR based) | 
|  |  | 
|  | By default (:attr:`driver`\ `= None`), we call `'gesvdj'` and, if it fails, we fallback to `'gesvd'`. | 
|  |  | 
|  | Differences with `numpy.linalg.svd`: | 
|  |  | 
|  | - Unlike `numpy.linalg.svd`, this function always returns a tuple of three tensors | 
|  | and it doesn't support `compute_uv` argument. | 
|  | Please use :func:`torch.linalg.svdvals`, which computes only the singular values, | 
|  | instead of `compute_uv=False`. | 
|  |  | 
|  | .. note:: When :attr:`full_matrices`\ `= True`, the gradients with respect to `U[..., :, min(m, n):]` | 
|  | and `Vh[..., min(m, n):, :]` will be ignored, as those vectors can be arbitrary bases | 
|  | of the corresponding subspaces. | 
|  |  | 
|  | .. warning:: The returned tensors `U` and `V` are not unique, nor are they continuous with | 
|  | respect to :attr:`A`. | 
|  | Due to this lack of uniqueness, different hardware and software may compute | 
|  | different singular vectors. | 
|  |  | 
|  | This non-uniqueness is caused by the fact that multiplying any pair of singular | 
|  | vectors :math:`u_k, v_k` by `-1` in the real case or by | 
|  | :math:`e^{i \phi}, \phi \in \mathbb{R}` in the complex case produces another two | 
|  | valid singular vectors of the matrix. | 
|  | For this reason, the loss function shall not depend on this :math:`e^{i \phi}` quantity, | 
|  | as it is not well-defined. | 
|  | This is checked for complex inputs when computing the gradients of this function. As such, | 
|  | when inputs are complex and are on a CUDA device, the computation of the gradients | 
|  | of this function synchronizes that device with the CPU. | 
|  |  | 
|  | .. warning:: Gradients computed using `U` or `Vh` will only be finite when | 
|  | :attr:`A` does not have repeated singular values. If :attr:`A` is rectangular, | 
|  | additionally, zero must also not be one of its singular values. | 
|  | Furthermore, if the distance between any two singular values is close to zero, | 
|  | the gradient will be numerically unstable, as it depends on the singular values | 
|  | :math:`\sigma_i` through the computation of | 
|  | :math:`\frac{1}{\min_{i \neq j} \sigma_i^2 - \sigma_j^2}`. | 
|  | In the rectangular case, the gradient will also be numerically unstable when | 
|  | :attr:`A` has small singular values, as it also depends on the computation of | 
|  | :math:`\frac{1}{\sigma_i}`. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.svdvals` computes only the singular values. | 
|  | Unlike :func:`torch.linalg.svd`, the gradients of :func:`~svdvals` are always | 
|  | numerically stable. | 
|  |  | 
|  | :func:`torch.linalg.eig` for a function that computes another type of spectral | 
|  | decomposition of a matrix. The eigendecomposition works just on square matrices. | 
|  |  | 
|  | :func:`torch.linalg.eigh` for a (faster) function that computes the eigenvalue decomposition | 
|  | for Hermitian and symmetric matrices. | 
|  |  | 
|  | :func:`torch.linalg.qr` for another (much faster) decomposition that works on general | 
|  | matrices. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  | full_matrices (bool, optional): controls whether to compute the full or reduced | 
|  | SVD, and consequently, | 
|  | the shape of the returned tensors | 
|  | `U` and `Vh`. Default: `True`. | 
|  |  | 
|  | Keyword args: | 
|  | driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. | 
|  | Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. | 
|  | Default: `None`. | 
|  | out (tuple, optional): output tuple of three tensors. Ignored if `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(U, S, Vh)` which corresponds to :math:`U`, :math:`S`, :math:`V^{\text{H}}` above. | 
|  |  | 
|  | `S` will always be real-valued, even when :attr:`A` is complex. | 
|  | It will also be ordered in descending order. | 
|  |  | 
|  | `U` and `Vh` will have the same dtype as :attr:`A`. The left / right singular vectors will be given by | 
|  | the columns of `U` and the rows of `Vh` respectively. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(5, 3) | 
|  | >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) | 
|  | >>> U.shape, S.shape, Vh.shape | 
|  | (torch.Size([5, 3]), torch.Size([3]), torch.Size([3, 3])) | 
|  | >>> torch.dist(A, U @ torch.diag(S) @ Vh) | 
|  | tensor(1.0486e-06) | 
|  |  | 
|  | >>> U, S, Vh = torch.linalg.svd(A) | 
|  | >>> U.shape, S.shape, Vh.shape | 
|  | (torch.Size([5, 5]), torch.Size([3]), torch.Size([3, 3])) | 
|  | >>> torch.dist(A, U[:, :3] @ torch.diag(S) @ Vh) | 
|  | tensor(1.0486e-06) | 
|  |  | 
|  | >>> A = torch.randn(7, 5, 3) | 
|  | >>> U, S, Vh = torch.linalg.svd(A, full_matrices=False) | 
|  | >>> torch.dist(A, U @ torch.diag_embed(S) @ Vh) | 
|  | tensor(3.0957e-06) | 
|  |  | 
|  | .. _condition number: | 
|  | https://pytorch.org/docs/master/linalg.html#torch.linalg.cond | 
|  | .. _the resulting vectors will span the same subspace: | 
|  | https://en.wikipedia.org/wiki/Singular_value_decomposition#Singular_values,_singular_vectors,_and_their_relation_to_the_SVD | 
|  | """) | 
|  |  | 
|  | svdvals = _add_docstr(_linalg.linalg_svdvals, r""" | 
|  | linalg.svdvals(A, *, driver=None, out=None) -> Tensor | 
|  |  | 
|  | Computes the singular values of a matrix. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | The singular values are returned in descending order. | 
|  |  | 
|  | .. note:: This function is equivalent to NumPy's `linalg.svd(A, compute_uv=False)`. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.svd` computes the full singular value decomposition. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. | 
|  | Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. | 
|  | Check :func:`torch.linalg.svd` for details. | 
|  | Default: `None`. | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A real-valued tensor, even when :attr:`A` is complex. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(5, 3) | 
|  | >>> S = torch.linalg.svdvals(A) | 
|  | >>> S | 
|  | tensor([2.5139, 2.1087, 1.1066]) | 
|  |  | 
|  | >>> torch.dist(S, torch.linalg.svd(A, full_matrices=False).S) | 
|  | tensor(2.4576e-07) | 
|  | """) | 
|  |  | 
|  | cond = _add_docstr(_linalg.linalg_cond, r""" | 
|  | linalg.cond(A, p=None, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the condition number of a matrix with respect to a matrix norm. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **condition number** :math:`\kappa` of a matrix | 
|  | :math:`A \in \mathbb{K}^{n \times n}` is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \kappa(A) = \|A\|_p\|A^{-1}\|_p | 
|  |  | 
|  | The condition number of :attr:`A` measures the numerical stability of the linear system `AX = B` | 
|  | with respect to a matrix norm. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | :attr:`p` defines the matrix norm that is computed. The following norms are supported: | 
|  |  | 
|  | =========    ================================= | 
|  | :attr:`p`    matrix norm | 
|  | =========    ================================= | 
|  | `None`       `2`-norm (largest singular value) | 
|  | `'fro'`      Frobenius norm | 
|  | `'nuc'`      nuclear norm | 
|  | `inf`        `max(sum(abs(x), dim=1))` | 
|  | `-inf`       `min(sum(abs(x), dim=1))` | 
|  | `1`          `max(sum(abs(x), dim=0))` | 
|  | `-1`         `min(sum(abs(x), dim=0))` | 
|  | `2`          largest singular value | 
|  | `-2`         smallest singular value | 
|  | =========    ================================= | 
|  |  | 
|  | where `inf` refers to `float('inf')`, NumPy's `inf` object, or any equivalent object. | 
|  |  | 
|  | For :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`, this function uses | 
|  | :func:`torch.linalg.norm` and :func:`torch.linalg.inv`. | 
|  | As such, in this case, the matrix (or every matrix in the batch) :attr:`A` has to be square | 
|  | and invertible. | 
|  |  | 
|  | For :attr:`p` in `(2, -2)`, this function can be computed in terms of the singular values | 
|  | :math:`\sigma_1 \geq \ldots \geq \sigma_n` | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \kappa_2(A) = \frac{\sigma_1}{\sigma_n}\qquad \kappa_{-2}(A) = \frac{\sigma_n}{\sigma_1} | 
|  |  | 
|  | In these cases, it is computed using :func:`torch.linalg.svdvals`. For these norms, the matrix | 
|  | (or every matrix in the batch) :attr:`A` may have any shape. | 
|  |  | 
|  | .. note :: When inputs are on a CUDA device, this function synchronizes that device with the CPU | 
|  | if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)`. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.solve` for a function that solves linear systems of square matrices. | 
|  |  | 
|  | :func:`torch.linalg.lstsq` for a function that solves linear systems of general matrices. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions | 
|  | for :attr:`p` in `(2, -2)`, and of shape `(*, n, n)` where every matrix | 
|  | is invertible for :attr:`p` in `('fro', 'nuc', inf, -inf, 1, -1)`. | 
|  | p (int, inf, -inf, 'fro', 'nuc', optional): | 
|  | the type of the matrix norm to use in the computations (see above). Default: `None` | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A real-valued tensor, even when :attr:`A` is complex. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: | 
|  | if :attr:`p` is one of `('fro', 'nuc', inf, -inf, 1, -1)` | 
|  | and the :attr:`A` matrix or any matrix in the batch :attr:`A` is not square | 
|  | or invertible. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 4, 4, dtype=torch.complex64) | 
|  | >>> torch.linalg.cond(A) | 
|  | >>> A = torch.tensor([[1., 0, -1], [0, 1, 0], [1, 0, 1]]) | 
|  | >>> torch.linalg.cond(A) | 
|  | tensor([1.4142]) | 
|  | >>> torch.linalg.cond(A, 'fro') | 
|  | tensor(3.1623) | 
|  | >>> torch.linalg.cond(A, 'nuc') | 
|  | tensor(9.2426) | 
|  | >>> torch.linalg.cond(A, float('inf')) | 
|  | tensor(2.) | 
|  | >>> torch.linalg.cond(A, float('-inf')) | 
|  | tensor(1.) | 
|  | >>> torch.linalg.cond(A, 1) | 
|  | tensor(2.) | 
|  | >>> torch.linalg.cond(A, -1) | 
|  | tensor(1.) | 
|  | >>> torch.linalg.cond(A, 2) | 
|  | tensor([1.4142]) | 
|  | >>> torch.linalg.cond(A, -2) | 
|  | tensor([0.7071]) | 
|  |  | 
|  | >>> A = torch.randn(2, 3, 3) | 
|  | >>> torch.linalg.cond(A) | 
|  | tensor([[9.5917], | 
|  | [3.2538]]) | 
|  | >>> A = torch.randn(2, 3, 3, dtype=torch.complex64) | 
|  | >>> torch.linalg.cond(A) | 
|  | tensor([[4.6245], | 
|  | [4.5671]]) | 
|  | """) | 
|  |  | 
|  | pinv = _add_docstr(_linalg.linalg_pinv, r""" | 
|  | linalg.pinv(A, *, atol=None, rtol=None, hermitian=False, out=None) -> Tensor | 
|  |  | 
|  | Computes the pseudoinverse (Moore-Penrose inverse) of a matrix. | 
|  |  | 
|  | The pseudoinverse may be `defined algebraically`_ | 
|  | but it is more computationally convenient to understand it `through the SVD`_ | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | If :attr:`hermitian`\ `= True`, :attr:`A` is assumed to be Hermitian if complex or | 
|  | symmetric if real, but this is not checked internally. Instead, just the lower | 
|  | triangular part of the matrix is used in the computations. | 
|  |  | 
|  | The singular values (or the norm of the eigenvalues when :attr:`hermitian`\ `= True`) | 
|  | that are below :math:`\max(\text{atol}, \sigma_1 \cdot \text{rtol})` threshold are | 
|  | treated as zero and discarded in the computation, | 
|  | where :math:`\sigma_1` is the largest singular value (or eigenvalue). | 
|  |  | 
|  | If :attr:`rtol` is not specified and :attr:`A` is a matrix of dimensions `(m, n)`, | 
|  | the relative tolerance is set to be :math:`\text{rtol} = \max(m, n) \varepsilon` | 
|  | and :math:`\varepsilon` is the epsilon value for the dtype of :attr:`A` (see :class:`.finfo`). | 
|  | If :attr:`rtol` is not specified and :attr:`atol` is specified to be larger than zero then | 
|  | :attr:`rtol` is set to zero. | 
|  |  | 
|  | If :attr:`atol` or :attr:`rtol` is a :class:`torch.Tensor`, its shape must be broadcastable to that | 
|  | of the singular values of :attr:`A` as returned by :func:`torch.linalg.svd`. | 
|  |  | 
|  | .. note:: This function uses :func:`torch.linalg.svd` if :attr:`hermitian`\ `= False` and | 
|  | :func:`torch.linalg.eigh` if :attr:`hermitian`\ `= True`. | 
|  | For CUDA inputs, this function synchronizes that device with the CPU. | 
|  |  | 
|  | .. note:: | 
|  | Consider using :func:`torch.linalg.lstsq` if possible for multiplying a matrix on the left by | 
|  | the pseudoinverse, as:: | 
|  |  | 
|  | torch.linalg.lstsq(A, B).solution == A.pinv() @ B | 
|  |  | 
|  | It is always prefered to use :func:`~lstsq` when possible, as it is faster and more | 
|  | numerically stable than computing the pseudoinverse explicitly. | 
|  |  | 
|  | .. note:: | 
|  | This function has NumPy compatible variant `linalg.pinv(A, rcond, hermitian=False)`. | 
|  | However, use of the positional argument :attr:`rcond` is deprecated in favor of :attr:`rtol`. | 
|  |  | 
|  | .. warning:: | 
|  | This function uses internally :func:`torch.linalg.svd` (or :func:`torch.linalg.eigh` | 
|  | when :attr:`hermitian`\ `= True`), so its derivative has the same problems as those of these | 
|  | functions. See the warnings in :func:`torch.linalg.svd` and :func:`torch.linalg.eigh` for | 
|  | more details. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.inv` computes the inverse of a square matrix. | 
|  |  | 
|  | :func:`torch.linalg.lstsq` computes :attr:`A`\ `.pinv() @ \ `:attr:`B` with a | 
|  | numerically stable algorithm. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  | rcond (float, Tensor, optional): [NumPy Compat]. Alias for :attr:`rtol`. Default: `None`. | 
|  |  | 
|  | Keyword args: | 
|  | atol (float, Tensor, optional): the absolute tolerance value. When `None` it's considered to be zero. | 
|  | Default: `None`. | 
|  | rtol (float, Tensor, optional): the relative tolerance value. See above for the value it takes when `None`. | 
|  | Default: `None`. | 
|  | hermitian(bool, optional): indicates whether :attr:`A` is Hermitian if complex | 
|  | or symmetric if real. Default: `False`. | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 5) | 
|  | >>> A | 
|  | tensor([[ 0.5495,  0.0979, -1.4092, -0.1128,  0.4132], | 
|  | [-1.1143, -0.3662,  0.3042,  1.6374, -0.9294], | 
|  | [-0.3269, -0.5745, -0.0382, -0.5922, -0.6759]]) | 
|  | >>> torch.linalg.pinv(A) | 
|  | tensor([[ 0.0600, -0.1933, -0.2090], | 
|  | [-0.0903, -0.0817, -0.4752], | 
|  | [-0.7124, -0.1631, -0.2272], | 
|  | [ 0.1356,  0.3933, -0.5023], | 
|  | [-0.0308, -0.1725, -0.5216]]) | 
|  |  | 
|  | >>> A = torch.randn(2, 6, 3) | 
|  | >>> Apinv = torch.linalg.pinv(A) | 
|  | >>> torch.dist(Apinv @ A, torch.eye(3)) | 
|  | tensor(8.5633e-07) | 
|  |  | 
|  | >>> A = torch.randn(3, 3, dtype=torch.complex64) | 
|  | >>> A = A + A.T.conj()  # creates a Hermitian matrix | 
|  | >>> Apinv = torch.linalg.pinv(A, hermitian=True) | 
|  | >>> torch.dist(Apinv @ A, torch.eye(3)) | 
|  | tensor(1.0830e-06) | 
|  |  | 
|  | .. _defined algebraically: | 
|  | https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Existence_and_uniqueness | 
|  | .. _through the SVD: | 
|  | https://en.wikipedia.org/wiki/Moore%E2%80%93Penrose_inverse#Singular_value_decomposition_(SVD) | 
|  | """) | 
|  |  | 
|  | matrix_exp = _add_docstr(_linalg.linalg_matrix_exp, r""" | 
|  | linalg.matrix_exp(A) -> Tensor | 
|  |  | 
|  | Computes the matrix exponential of a square matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | this function computes the **matrix exponential** of :math:`A \in \mathbb{K}^{n \times n}`, which is defined as | 
|  |  | 
|  | .. math:: | 
|  | \mathrm{matrix_exp}(A) = \sum_{k=0}^\infty \frac{1}{k!}A^k \in \mathbb{K}^{n \times n}. | 
|  |  | 
|  | If the matrix :math:`A` has eigenvalues :math:`\lambda_i \in \mathbb{C}`, | 
|  | the matrix :math:`\mathrm{matrix_exp}(A)` has eigenvalues :math:`e^{\lambda_i} \in \mathbb{C}`. | 
|  |  | 
|  | Supports input of bfloat16, float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Example:: | 
|  |  | 
|  | >>> A = torch.empty(2, 2, 2) | 
|  | >>> A[0, :, :] = torch.eye(2, 2) | 
|  | >>> A[1, :, :] = 2 * torch.eye(2, 2) | 
|  | >>> A | 
|  | tensor([[[1., 0.], | 
|  | [0., 1.]], | 
|  |  | 
|  | [[2., 0.], | 
|  | [0., 2.]]]) | 
|  | >>> torch.linalg.matrix_exp(A) | 
|  | tensor([[[2.7183, 0.0000], | 
|  | [0.0000, 2.7183]], | 
|  |  | 
|  | [[7.3891, 0.0000], | 
|  | [0.0000, 7.3891]]]) | 
|  |  | 
|  | >>> import math | 
|  | >>> A = torch.tensor([[0, math.pi/3], [-math.pi/3, 0]]) # A is skew-symmetric | 
|  | >>> torch.linalg.matrix_exp(A) # matrix_exp(A) = [[cos(pi/3), sin(pi/3)], [-sin(pi/3), cos(pi/3)]] | 
|  | tensor([[ 0.5000,  0.8660], | 
|  | [-0.8660,  0.5000]]) | 
|  | """) | 
|  |  | 
|  |  | 
|  | solve = _add_docstr(_linalg.linalg_solve, r""" | 
|  | linalg.solve(A, B, *, left=True, out=None) -> Tensor | 
|  |  | 
|  | Computes the solution of a square system of linear equations with a unique solution. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to | 
|  | :math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as | 
|  |  | 
|  | .. math:: AX = B | 
|  |  | 
|  | If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} | 
|  |  | 
|  | This system of linear equations has one solution if and only if :math:`A` is `invertible`_. | 
|  | This function assumes that :math:`A` is invertible. | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if the inputs are batches of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | Letting `*` be zero or more batch dimensions, | 
|  |  | 
|  | - If :attr:`A` has shape `(*, n, n)` and :attr:`B` has shape `(*, n)` (a batch of vectors) or shape | 
|  | `(*, n, k)` (a batch of matrices or "multiple right-hand sides"), this function returns `X` of shape | 
|  | `(*, n)` or `(*, n, k)` respectively. | 
|  | - Otherwise, if :attr:`A` has shape `(*, n, n)` and  :attr:`B` has shape `(n,)`  or `(n, k)`, :attr:`B` | 
|  | is broadcasted to have shape `(*, n)` or `(*, n, k)` respectively. | 
|  | This function then returns the solution of the resulting batch of systems of linear equations. | 
|  |  | 
|  | .. note:: | 
|  | This function computes `X = \ `:attr:`A`\ `.inverse() @ \ `:attr:`B` in a faster and | 
|  | more numerically stable way than performing the computations separately. | 
|  |  | 
|  | .. note:: | 
|  | It is possible to compute the solution of the system :math:`XA = B` by passing the inputs | 
|  | :attr:`A` and :attr:`B` transposed and transposing the output returned by this function. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note"]} | 
|  | """ + r""" | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.solve_triangular` computes the solution of a triangular system of linear | 
|  | equations with a unique solution. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions. | 
|  | B (Tensor): right-hand side tensor of shape `(*, n)` or  `(*, n, k)` or `(n,)` or `(n, k)` | 
|  | according to the rules described above | 
|  |  | 
|  | Keyword args: | 
|  | left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` | 
|  | is not invertible. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> b = torch.randn(3) | 
|  | >>> x = torch.linalg.solve(A, b) | 
|  | >>> torch.allclose(A @ x, b) | 
|  | True | 
|  | >>> A = torch.randn(2, 3, 3) | 
|  | >>> B = torch.randn(2, 3, 4) | 
|  | >>> X = torch.linalg.solve(A, B) | 
|  | >>> X.shape | 
|  | torch.Size([2, 3, 4]) | 
|  | >>> torch.allclose(A @ X, B) | 
|  | True | 
|  |  | 
|  | >>> A = torch.randn(2, 3, 3) | 
|  | >>> b = torch.randn(3, 1) | 
|  | >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3, 1) | 
|  | >>> x.shape | 
|  | torch.Size([2, 3, 1]) | 
|  | >>> torch.allclose(A @ x, b) | 
|  | True | 
|  | >>> b = torch.randn(3) | 
|  | >>> x = torch.linalg.solve(A, b) # b is broadcasted to size (2, 3) | 
|  | >>> x.shape | 
|  | torch.Size([2, 3]) | 
|  | >>> Ax = A @ x.unsqueeze(-1) | 
|  | >>> torch.allclose(Ax, b.unsqueeze(-1).expand_as(Ax)) | 
|  | True | 
|  |  | 
|  | .. _invertible: | 
|  | https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem | 
|  | """) | 
|  |  | 
|  | solve_triangular = _add_docstr(_linalg.linalg_solve_triangular, r""" | 
|  | linalg.solve_triangular(A, B, *, upper, left=True, unitriangular=False, out=None) -> Tensor | 
|  |  | 
|  | Computes the solution of a triangular system of linear equations with a unique solution. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** | 
|  | associated to the triangular matrix :math:`A \in \mathbb{K}^{n \times n}` without zeros on the diagonal | 
|  | (that is, it is `invertible`_) and the rectangular matrix , :math:`B \in \mathbb{K}^{n \times k}`, | 
|  | which is defined as | 
|  |  | 
|  | .. math:: AX = B | 
|  |  | 
|  | The argument :attr:`upper` signals whether :math:`A` is upper or lower triangular. | 
|  |  | 
|  | If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that | 
|  | solves the system | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} | 
|  |  | 
|  | If :attr:`upper`\ `= True` (resp. `False`) just the upper (resp. lower) triangular half of :attr:`A` | 
|  | will be accessed. The elements below the main diagonal will be considered to be zero and will not be accessed. | 
|  |  | 
|  | If :attr:`unitriangular`\ `= True`, the diagonal of :attr:`A` is assumed to be ones and will not be accessed. | 
|  |  | 
|  | The result may contain `NaN` s if the diagonal of :attr:`A` contains zeros or elements that | 
|  | are very close to zero and :attr:`unitriangular`\ `= False` (default) or if the input matrix | 
|  | has very small eigenvalues. | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if the inputs are batches of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.solve` computes the solution of a general square system of linear | 
|  | equations with a unique solution. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) | 
|  | where `*` is zero or more batch dimensions. | 
|  | B (Tensor): right-hand side tensor of shape `(*, n, k)`. | 
|  |  | 
|  | Keyword args: | 
|  | upper (bool): whether :attr:`A` is an upper or lower triangular matrix. | 
|  | left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. | 
|  | unitriangular (bool, optional): if `True`, the diagonal elements of :attr:`A` are assumed to be | 
|  | all equal to `1`. Default: `False`. | 
|  | out (Tensor, optional): output tensor. `B` may be passed as `out` and the result is computed in-place on `B`. | 
|  | Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3).triu_() | 
|  | >>> b = torch.randn(3, 4) | 
|  | >>> X = torch.linalg.solve_triangular(A, B, upper=True) | 
|  | >>> torch.allclose(A @ X, B) | 
|  | True | 
|  |  | 
|  | >>> A = torch.randn(2, 3, 3).tril_() | 
|  | >>> B = torch.randn(2, 3, 4) | 
|  | >>> X = torch.linalg.solve_triangular(A, B, upper=False) | 
|  | >>> torch.allclose(A @ X, B) | 
|  | True | 
|  |  | 
|  | >>> A = torch.randn(2, 4, 4).tril_() | 
|  | >>> B = torch.randn(2, 3, 4) | 
|  | >>> X = torch.linalg.solve_triangular(A, B, upper=False, left=False) | 
|  | >>> torch.allclose(X @ A, B) | 
|  | True | 
|  |  | 
|  | .. _invertible: | 
|  | https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem | 
|  | """) | 
|  |  | 
|  | lu_factor = _add_docstr(_linalg.linalg_lu_factor, r""" | 
|  | linalg.lu_factor(A, *, bool pivot=True, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes a compact representation of the LU factorization with partial pivoting of a matrix. | 
|  |  | 
|  | This function computes a compact representation of the decomposition given by :func:`torch.linalg.lu`. | 
|  | If the matrix is square, this representation may be used in :func:`torch.linalg.lu_solve` | 
|  | to solve system of linear equations that share the matrix :attr:`A`. | 
|  |  | 
|  | The returned decomposition is represented as a named tuple `(LU, pivots)`. | 
|  | The ``LU`` matrix has the same shape as the input matrix ``A``. Its upper and lower triangular | 
|  | parts encode the non-constant elements of ``L`` and ``U`` of the LU decomposition of ``A``. | 
|  |  | 
|  | The returned permutation matrix is represented by a 1-indexed vector. `pivots[i] == j` represents | 
|  | that in the `i`-th step of the algorithm, the `i`-th row was permuted with the `j-1`-th row. | 
|  |  | 
|  | On CUDA, one may use :attr:`pivot`\ `= False`. In this case, this function returns the LU | 
|  | decomposition without pivoting if it exists. | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if the inputs are batches of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note_has_ex"].format("torch.linalg.lu_factor_ex")} | 
|  | """ + r""" | 
|  | .. warning:: The LU decomposition is almost never unique, as often there are different permutation | 
|  | matrices that can yield different LU decompositions. | 
|  | As such, different platforms, like SciPy, or inputs on different devices, | 
|  | may produce different valid decompositions. | 
|  |  | 
|  | Gradient computations are only supported if the input matrix is full-rank. | 
|  | If this condition is not met, no error will be thrown, but the gradient may not be finite. | 
|  | This is because the LU decomposition with pivoting is not differentiable at these points. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.lu_solve` solves a system of linear equations given the output of this | 
|  | function provided the input matrix was square and invertible. | 
|  |  | 
|  | :func:`torch.lu_unpack` unpacks the tensors returned by :func:`~lu_factor` into the three | 
|  | matrices `P, L, U` that form the decomposition. | 
|  |  | 
|  | :func:`torch.linalg.lu` computes the LU decomposition with partial pivoting of a possibly | 
|  | non-square matrix. It is a composition of :func:`~lu_factor` and :func:`torch.lu_unpack`. | 
|  |  | 
|  | :func:`torch.linalg.solve` solves a system of linear equations. It is a composition | 
|  | of :func:`~lu_factor` and :func:`~lu_solve`. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU | 
|  | decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. | 
|  | out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(LU, pivots)`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if the :attr:`A` matrix is not invertible or any matrix in a batched :attr:`A` | 
|  | is not invertible. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(2, 3, 3) | 
|  | >>> B1 = torch.randn(2, 3, 4) | 
|  | >>> B2 = torch.randn(2, 3, 7) | 
|  | >>> A_factor = torch.linalg.lu_factor(A) | 
|  | >>> X1 = torch.linalg.lu_solve(A_factor, B1) | 
|  | >>> X2 = torch.linalg.lu_solve(A_factor, B2) | 
|  | >>> torch.allclose(A @ X1, B1) | 
|  | True | 
|  | >>> torch.allclose(A @ X2, B2) | 
|  | True | 
|  |  | 
|  | .. _invertible: | 
|  | https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem | 
|  | """) | 
|  |  | 
|  | lu_factor_ex = _add_docstr(_linalg.linalg_lu_factor_ex, r""" | 
|  | linalg.lu_factor_ex(A, *, pivot=True, check_errors=False, out=None) -> (Tensor, Tensor, Tensor) | 
|  |  | 
|  | This is a version of :func:`~lu_factor` that does not perform error checks unless :attr:`check_errors`\ `= True`. | 
|  | It also returns the :attr:`info` tensor returned by `LAPACK's getrf`_. | 
|  |  | 
|  | """ + fr""" | 
|  | .. note:: {common_notes["sync_note_ex"]} | 
|  |  | 
|  | .. warning:: {common_notes["experimental_warning"]} | 
|  | """ + r""" | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  |  | 
|  | Keyword args: | 
|  | pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU | 
|  | decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`. | 
|  | check_errors (bool, optional): controls whether to check the content of ``infos`` and raise | 
|  | an error if it is non-zero. Default: `False`. | 
|  | out (tuple, optional): tuple of three tensors to write the output to. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(LU, pivots, info)`. | 
|  |  | 
|  | .. _LAPACK's getrf: | 
|  | https://www.netlib.org/lapack/explore-html/dd/d9a/group__double_g_ecomputational_ga0019443faea08275ca60a734d0593e60.html | 
|  | """) | 
|  |  | 
|  | lu_solve = _add_docstr(_linalg.linalg_lu_solve, r""" | 
|  | linalg.lu_solve(LU, pivots, B, *, left=True, adjoint=False, out=None) -> Tensor | 
|  |  | 
|  | Computes the solution of a square system of linear equations with a unique solution given an LU decomposition. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | this function computes the solution :math:`X \in \mathbb{K}^{n \times k}` of the **linear system** associated to | 
|  | :math:`A \in \mathbb{K}^{n \times n}, B \in \mathbb{K}^{n \times k}`, which is defined as | 
|  |  | 
|  | .. math:: AX = B | 
|  |  | 
|  | where :math:`A` is given factorized as returned by :func:`~lu_factor`. | 
|  |  | 
|  | If :attr:`left`\ `= False`, this function returns the matrix :math:`X \in \mathbb{K}^{n \times k}` that solves the system | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | XA = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} | 
|  |  | 
|  | If  :attr:`adjoint`\ `= True` (and :attr:`left`\ `= True), given an LU factorization of :math:`A` | 
|  | this function function returns the :math:`X \in \mathbb{K}^{n \times k}` that solves the system | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A^{\text{H}}X = B\mathrlap{\qquad A \in \mathbb{K}^{k \times k}, B \in \mathbb{K}^{n \times k}.} | 
|  |  | 
|  | where :math:`A^{\text{H}}` is the conjugate transpose when :math:`A` is complex, and the | 
|  | transpose when :math:`A` is real-valued. The :attr:`left`\ `= False` case is analogous. | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if the inputs are batches of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | Args: | 
|  | LU (Tensor): tensor of shape `(*, n, n)` (or `(*, k, k)` if :attr:`left`\ `= True`) | 
|  | where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. | 
|  | pivots (Tensor): tensor of shape `(*, n)` (or `(*, k)` if :attr:`left`\ `= True`) | 
|  | where `*` is zero or more batch dimensions as returned by :func:`~lu_factor`. | 
|  | B (Tensor): right-hand side tensor of shape `(*, n, k)`. | 
|  |  | 
|  | Keyword args: | 
|  | left (bool, optional): whether to solve the system :math:`AX=B` or :math:`XA = B`. Default: `True`. | 
|  | adjoint (bool, optional): whether to solve the system :math:`AX=B` or :math:`A^{\text{H}}X = B`. Default: `False`. | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 3) | 
|  | >>> LU, pivots = torch.linalg.lu_factor(A) | 
|  | >>> B = torch.randn(3, 2) | 
|  | >>> X = torch.linalg.lu_solve(LU, pivots, B) | 
|  | >>> torch.allclose(A @ X, B) | 
|  | True | 
|  |  | 
|  | >>> B = torch.randn(3, 3, 2)   # Broadcasting rules apply: A is broadcasted | 
|  | >>> X = torch.linalg.lu_solve(LU, pivots, B) | 
|  | >>> torch.allclose(A @ X, B) | 
|  | True | 
|  |  | 
|  | >>> B = torch.randn(3, 5, 3) | 
|  | >>> X = torch.linalg.lu_solve(LU, pivots, B, left=False) | 
|  | >>> torch.allclose(X @ A, B) | 
|  | True | 
|  |  | 
|  | >>> B = torch.randn(3, 3, 4)   # Now solve for A^T | 
|  | >>> X = torch.linalg.lu_solve(LU, pivots, B, adjoint=True) | 
|  | >>> torch.allclose(A.mT @ X, B) | 
|  | True | 
|  |  | 
|  | .. _invertible: | 
|  | https://en.wikipedia.org/wiki/Invertible_matrix#The_invertible_matrix_theorem | 
|  | """) | 
|  |  | 
|  | lu = _add_docstr(_linalg.linalg_lu, r""" | 
|  | lu(A, *, pivot=True, out=None) -> (Tensor, Tensor, Tensor) | 
|  |  | 
|  | Computes the LU decomposition with partial pivoting of a matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **LU decomposition with partial pivoting** of a matrix | 
|  | :math:`A \in \mathbb{K}^{m \times n}` is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = PLU\mathrlap{\qquad P \in \mathbb{K}^{m \times m}, L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} | 
|  |  | 
|  | where `k = min(m,n)`, :math:`P` is a `permutation matrix`_, :math:`L` is lower triangular with ones on the diagonal | 
|  | and :math:`U` is upper triangular. | 
|  |  | 
|  | If :attr:`pivot`\ `= False` and :attr:`A` is on GPU, then the **LU decomposition without pivoting** is computed | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = LU\mathrlap{\qquad L \in \mathbb{K}^{m \times k}, U \in \mathbb{K}^{k \times n}} | 
|  |  | 
|  | When :attr:`pivot`\ `= False`, the returned matrix :attr:`P` will be empty. | 
|  | The LU decomposition without pivoting `may not exist`_ if any of the principal minors of :attr:`A` is singular. | 
|  | In this case, the output matrix may contain `inf` or `NaN`. | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.solve` solves a system of linear equations using the LU decomposition | 
|  | with partial pivoting. | 
|  |  | 
|  | .. warning:: The LU decomposition is almost never unique, as often there are different permutation | 
|  | matrices that can yield different LU decompositions. | 
|  | As such, different platforms, like SciPy, or inputs on different devices, | 
|  | may produce different valid decompositions. | 
|  |  | 
|  | .. warning:: Gradient computations are only supported if the input matrix is full-rank. | 
|  | If this condition is not met, no error will be thrown, but the gradient | 
|  | may not be finite. | 
|  | This is because the LU decomposition with pivoting is not differentiable at these points. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  | pivot (bool, optional): Controls whether to compute the LU decomposition with partial pivoting or | 
|  | no pivoting. Default: `True`. | 
|  |  | 
|  | Keyword args: | 
|  | out (tuple, optional): output tuple of three tensors. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(P, L, U)`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.randn(3, 2) | 
|  | >>> P, L, U = torch.linalg.lu(A) | 
|  | >>> P | 
|  | tensor([[0., 1., 0.], | 
|  | [0., 0., 1.], | 
|  | [1., 0., 0.]]) | 
|  | >>> L | 
|  | tensor([[1.0000, 0.0000], | 
|  | [0.5007, 1.0000], | 
|  | [0.0633, 0.9755]]) | 
|  | >>> U | 
|  | tensor([[0.3771, 0.0489], | 
|  | [0.0000, 0.9644]]) | 
|  | >>> torch.dist(A, P @ L @ U) | 
|  | tensor(5.9605e-08) | 
|  |  | 
|  | >>> A = torch.randn(2, 5, 7, device="cuda") | 
|  | >>> P, L, U = torch.linalg.lu(A, pivot=False) | 
|  | >>> P | 
|  | tensor([], device='cuda:0') | 
|  | >>> torch.dist(A, L @ U) | 
|  | tensor(1.0376e-06, device='cuda:0') | 
|  |  | 
|  | .. _permutation matrix: | 
|  | https://en.wikipedia.org/wiki/Permutation_matrix | 
|  | .. _may not exist: | 
|  | https://en.wikipedia.org/wiki/LU_decomposition#Definitions | 
|  | """) | 
|  |  | 
|  | tensorinv = _add_docstr(_linalg.linalg_tensorinv, r""" | 
|  | linalg.tensorinv(A, ind=2, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the multiplicative inverse of :func:`torch.tensordot`. | 
|  |  | 
|  | If `m` is the product of the first :attr:`ind` dimensions of :attr:`A` and `n` is the product of | 
|  | the rest of the dimensions, this function expects `m` and `n` to be equal. | 
|  | If this is the case, it computes a tensor `X` such that | 
|  | `tensordot(\ `:attr:`A`\ `, X, \ `:attr:`ind`\ `)` is the identity matrix in dimension `m`. | 
|  | `X` will have the shape of :attr:`A` but with the first :attr:`ind` dimensions pushed back to the end | 
|  |  | 
|  | .. code:: text | 
|  |  | 
|  | X.shape == A.shape[ind:] + A.shape[:ind] | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  |  | 
|  | .. note:: When :attr:`A` is a `2`-dimensional tensor and :attr:`ind`\ `= 1`, | 
|  | this function computes the (multiplicative) inverse of :attr:`A` | 
|  | (see :func:`torch.linalg.inv`). | 
|  |  | 
|  | .. note:: | 
|  | Consider using :func:`torch.linalg.tensorsolve` if possible for multiplying a tensor on the left | 
|  | by the tensor inverse, as:: | 
|  |  | 
|  | linalg.tensorsolve(A, B) == torch.tensordot(linalg.tensorinv(A), B)  # When B is a tensor with shape A.shape[:B.ndim] | 
|  |  | 
|  | It is always prefered to use :func:`~tensorsolve` when possible, as it is faster and more | 
|  | numerically stable than computing the pseudoinverse explicitly. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.tensorsolve` computes | 
|  | `torch.tensordot(tensorinv(\ `:attr:`A`\ `), \ `:attr:`B`\ `)`. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor to invert. Its shape must satisfy | 
|  | `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`ind`\ `]) == | 
|  | prod(\ `:attr:`A`\ `.shape[\ `:attr:`ind`\ `:])`. | 
|  | ind (int): index at which to compute the inverse of :func:`torch.tensordot`. Default: `2`. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if the reshaped :attr:`A` is not invertible or the product of the first | 
|  | :attr:`ind` dimensions is not equal to the product of the rest. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.eye(4 * 6).reshape((4, 6, 8, 3)) | 
|  | >>> Ainv = torch.linalg.tensorinv(A, ind=2) | 
|  | >>> Ainv.shape | 
|  | torch.Size([8, 3, 4, 6]) | 
|  | >>> B = torch.randn(4, 6) | 
|  | >>> torch.allclose(torch.tensordot(Ainv, B), torch.linalg.tensorsolve(A, B)) | 
|  | True | 
|  |  | 
|  | >>> A = torch.randn(4, 4) | 
|  | >>> Atensorinv = torch.linalg.tensorinv(A, ind=1) | 
|  | >>> Ainv = torch.linalg.inverse(A) | 
|  | >>> torch.allclose(Atensorinv, Ainv) | 
|  | True | 
|  | """) | 
|  |  | 
|  | tensorsolve = _add_docstr(_linalg.linalg_tensorsolve, r""" | 
|  | linalg.tensorsolve(A, B, dims=None, *, out=None) -> Tensor | 
|  |  | 
|  | Computes the solution `X` to the system `torch.tensordot(A, X) = B`. | 
|  |  | 
|  | If `m` is the product of the first :attr:`B`\ `.ndim`  dimensions of :attr:`A` and | 
|  | `n` is the product of the rest of the dimensions, this function expects `m` and `n` to be equal. | 
|  |  | 
|  | The returned tensor `x` satisfies | 
|  | `tensordot(\ `:attr:`A`\ `, x, dims=x.ndim) == \ `:attr:`B`. | 
|  | `x` has shape :attr:`A`\ `[B.ndim:]`. | 
|  |  | 
|  | If :attr:`dims` is specified, :attr:`A` will be reshaped as | 
|  |  | 
|  | .. code:: text | 
|  |  | 
|  | A = movedim(A, dims, range(len(dims) - A.ndim + 1, 0)) | 
|  |  | 
|  | Supports inputs of float, double, cfloat and cdouble dtypes. | 
|  |  | 
|  | .. seealso:: | 
|  |  | 
|  | :func:`torch.linalg.tensorinv` computes the multiplicative inverse of | 
|  | :func:`torch.tensordot`. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor to solve for. Its shape must satisfy | 
|  | `prod(\ `:attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]) == | 
|  | prod(\ `:attr:`A`\ `.shape[\ `:attr:`B`\ `.ndim:])`. | 
|  | B (Tensor): tensor of shape :attr:`A`\ `.shape[:\ `:attr:`B`\ `.ndim]`. | 
|  | dims (Tuple[int], optional): dimensions of :attr:`A` to be moved. | 
|  | If `None`, no dimensions are moved. Default: `None`. | 
|  |  | 
|  | Keyword args: | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Raises: | 
|  | RuntimeError: if the reshaped :attr:`A`\ `.view(m, m)` with `m` as above  is not | 
|  | invertible or the product of the first :attr:`ind` dimensions is not equal | 
|  | to the product of the rest of the dimensions. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.eye(2 * 3 * 4).reshape((2 * 3, 4, 2, 3, 4)) | 
|  | >>> B = torch.randn(2 * 3, 4) | 
|  | >>> X = torch.linalg.tensorsolve(A, B) | 
|  | >>> X.shape | 
|  | torch.Size([2, 3, 4]) | 
|  | >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B) | 
|  | True | 
|  |  | 
|  | >>> A = torch.randn(6, 4, 4, 3, 2) | 
|  | >>> B = torch.randn(4, 3, 2) | 
|  | >>> X = torch.linalg.tensorsolve(A, B, dims=(0, 2)) | 
|  | >>> X.shape | 
|  | torch.Size([6, 4]) | 
|  | >>> A = A.permute(1, 3, 4, 0, 2) | 
|  | >>> A.shape[B.ndim:] | 
|  | torch.Size([6, 4]) | 
|  | >>> torch.allclose(torch.tensordot(A, X, dims=X.ndim), B, atol=1e-6) | 
|  | True | 
|  | """) | 
|  |  | 
|  | qr = _add_docstr(_linalg.linalg_qr, r""" | 
|  | qr(A, mode='reduced', *, out=None) -> (Tensor, Tensor) | 
|  |  | 
|  | Computes the QR decomposition of a matrix. | 
|  |  | 
|  | Letting :math:`\mathbb{K}` be :math:`\mathbb{R}` or :math:`\mathbb{C}`, | 
|  | the **full QR decomposition** of a matrix | 
|  | :math:`A \in \mathbb{K}^{m \times n}` is defined as | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times m}, R \in \mathbb{K}^{m \times n}} | 
|  |  | 
|  | where :math:`Q` is orthogonal in the real case and unitary in the complex case, | 
|  | and :math:`R` is upper triangular with real diagonal (even in the complex case). | 
|  |  | 
|  | When `m > n` (tall matrix), as `R` is upper triangular, its last `m - n` rows are zero. | 
|  | In this case, we can drop the last `m - n` columns of `Q` to form the | 
|  | **reduced QR decomposition**: | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | A = QR\mathrlap{\qquad Q \in \mathbb{K}^{m \times n}, R \in \mathbb{K}^{n \times n}} | 
|  |  | 
|  | The reduced QR decomposition agrees with the full QR decomposition when `n >= m` (wide matrix). | 
|  |  | 
|  | Supports input of float, double, cfloat and cdouble dtypes. | 
|  | Also supports batches of matrices, and if :attr:`A` is a batch of matrices then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | The parameter :attr:`mode` chooses between the full and reduced QR decomposition. | 
|  | If :attr:`A` has shape `(*, m, n)`, denoting `k = min(m, n)` | 
|  |  | 
|  | - :attr:`mode`\ `= 'reduced'` (default): Returns `(Q, R)` of shapes `(*, m, k)`, `(*, k, n)` respectively. | 
|  | It is always differentiable. | 
|  | - :attr:`mode`\ `= 'complete'`: Returns `(Q, R)` of shapes `(*, m, m)`, `(*, m, n)` respectively. | 
|  | It is differentiable for `m <= n`. | 
|  | - :attr:`mode`\ `= 'r'`: Computes only the reduced `R`. Returns `(Q, R)` with `Q` empty and `R` of shape `(*, k, n)`. | 
|  | It is never differentiable. | 
|  |  | 
|  | Differences with `numpy.linalg.qr`: | 
|  |  | 
|  | - :attr:`mode`\ `= 'raw'` is not implemented. | 
|  | - Unlike `numpy.linalg.qr`, this function always returns a tuple of two tensors. | 
|  | When :attr:`mode`\ `= 'r'`, the `Q` tensor is an empty tensor. | 
|  |  | 
|  | .. warning:: The elements in the diagonal of `R` are not necessarily positive. | 
|  | As such, the returned QR decomposition is only unique up to the sign of the diagonal of `R`. | 
|  | Therefore, different platforms, like NumPy, or inputs on different devices, | 
|  | may produce different valid decompositions. | 
|  |  | 
|  | .. warning:: The QR decomposition is only well-defined if the first `k = min(m, n)` columns | 
|  | of every matrix in :attr:`A` are linearly independent. | 
|  | If this condition is not met, no error will be thrown, but the QR produced | 
|  | may be incorrect and its autodiff may fail or produce incorrect results. | 
|  |  | 
|  | Args: | 
|  | A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. | 
|  | mode (str, optional): one of `'reduced'`, `'complete'`, `'r'`. | 
|  | Controls the shape of the returned tensors. Default: `'reduced'`. | 
|  |  | 
|  | Keyword args: | 
|  | out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Returns: | 
|  | A named tuple `(Q, R)`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> A = torch.tensor([[12., -51, 4], [6, 167, -68], [-4, 24, -41]]) | 
|  | >>> Q, R = torch.linalg.qr(A) | 
|  | >>> Q | 
|  | tensor([[-0.8571,  0.3943,  0.3314], | 
|  | [-0.4286, -0.9029, -0.0343], | 
|  | [ 0.2857, -0.1714,  0.9429]]) | 
|  | >>> R | 
|  | tensor([[ -14.0000,  -21.0000,   14.0000], | 
|  | [   0.0000, -175.0000,   70.0000], | 
|  | [   0.0000,    0.0000,  -35.0000]]) | 
|  | >>> (Q @ R).round() | 
|  | tensor([[  12.,  -51.,    4.], | 
|  | [   6.,  167.,  -68.], | 
|  | [  -4.,   24.,  -41.]]) | 
|  | >>> (Q.T @ Q).round() | 
|  | tensor([[ 1.,  0.,  0.], | 
|  | [ 0.,  1., -0.], | 
|  | [ 0., -0.,  1.]]) | 
|  | >>> Q2, R2 = torch.linalg.qr(A, mode='r') | 
|  | >>> Q2 | 
|  | tensor([]) | 
|  | >>> torch.equal(R, R2) | 
|  | True | 
|  | >>> A = torch.randn(3, 4, 5) | 
|  | >>> Q, R = torch.linalg.qr(A, mode='complete') | 
|  | >>> torch.dist(Q @ R, A) | 
|  | tensor(1.6099e-06) | 
|  | >>> torch.dist(Q.mT @ Q, torch.eye(4)) | 
|  | tensor(6.2158e-07) | 
|  | """) | 
|  |  | 
|  | vander = _add_docstr(_linalg.linalg_vander, r""" | 
|  | vander(x, N=None) -> Tensor | 
|  |  | 
|  | Generates a Vandermonde matrix. | 
|  |  | 
|  | Returns the Vandermonde matrix :math:`V` | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | V = \begin{pmatrix} | 
|  | 1 & x_1 & x_1^2 & \dots & x_1^{N-1}\\ | 
|  | 1 & x_2 & x_2^2 & \dots & x_2^{N-1}\\ | 
|  | 1 & x_3 & x_3^2 & \dots & x_3^{N-1}\\ | 
|  | \vdots & \vdots & \vdots & \ddots &\vdots \\ | 
|  | 1 & x_n & x_n^2 & \dots & x_n^{N-1} | 
|  | \end{pmatrix}. | 
|  |  | 
|  | for `N > 1`. | 
|  | If :attr:`N`\ `= None`, then `N = x.size(-1)` so that the output is a square matrix. | 
|  |  | 
|  | Supports inputs of float, double, cfloat, cdouble, and integral dtypes. | 
|  | Also supports batches of vectors, and if :attr:`x` is a batch of vectors then | 
|  | the output has the same batch dimensions. | 
|  |  | 
|  | Differences with `numpy.vander`: | 
|  |  | 
|  | - Unlike `numpy.vander`, this function returns the powers of :attr:`x` in ascending order. | 
|  | To get them in the reverse order call ``linalg.vander(x, N).flip(-1)``. | 
|  |  | 
|  | Args: | 
|  | x (Tensor): tensor of shape `(*, n)` where `*` is zero or more batch dimensions | 
|  | consisting of vectors. | 
|  |  | 
|  | Keyword args: | 
|  | N (int, optional): Number of columns in the output. Default: `x.size(-1)` | 
|  |  | 
|  | Example:: | 
|  |  | 
|  | >>> x = torch.tensor([1, 2, 3, 5]) | 
|  | >>> linalg.vander(x) | 
|  | tensor([[  1,   1,   1,   1], | 
|  | [  1,   2,   4,   8], | 
|  | [  1,   3,   9,  27], | 
|  | [  1,   5,  25, 125]]) | 
|  | >>> linalg.vander(x, N=3) | 
|  | tensor([[ 1,  1,  1], | 
|  | [ 1,  2,  4], | 
|  | [ 1,  3,  9], | 
|  | [ 1,  5, 25]]) | 
|  | """) | 
|  |  | 
|  | vecdot = _add_docstr(_linalg.linalg_vecdot, r""" | 
|  | linalg.vecdot(x, y, *, dim=-1, out=None) -> Tensor | 
|  |  | 
|  | Computes the dot product of two batches of vectors along a dimension. | 
|  |  | 
|  | In symbols, this function computes | 
|  |  | 
|  | .. math:: | 
|  |  | 
|  | \sum_{i=1}^n \overline{x_i}y_i. | 
|  |  | 
|  | over the dimension :attr:`dim` where :math:`\overline{x_i}` denotes the conjugate for complex | 
|  | vectors, and it is the identity for real vectors. | 
|  |  | 
|  | Supports input of half, bfloat16, float, double, cfloat, cdouble and integral dtypes. | 
|  | It also supports broadcasting. | 
|  |  | 
|  | Args: | 
|  | x (Tensor): first batch of vectors of shape `(*, n)`. | 
|  | y (Tensor): second batch of vectors of shape `(*, n)`. | 
|  |  | 
|  | Keyword args: | 
|  | dim (int): Dimension along which to compute the dot product. Default: `-1`. | 
|  | out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`. | 
|  |  | 
|  | Examples:: | 
|  |  | 
|  | >>> v1 = torch.randn(3, 2) | 
|  | >>> v2 = torch.randn(3, 2) | 
|  | >>> linalg.vecdot(v1, v2) | 
|  | tensor([ 0.3223,  0.2815, -0.1944]) | 
|  | >>> torch.vdot(v1[0], v2[0]) | 
|  | tensor(0.3223) | 
|  | """) |