| """Various linear algebra utility methods for internal use. | 
 |  | 
 | """ | 
 |  | 
 | from typing import Optional, Tuple | 
 |  | 
 | import torch | 
 | from torch import Tensor | 
 |  | 
 |  | 
 | def is_sparse(A): | 
 |     """Check if tensor A is a sparse tensor""" | 
 |     if isinstance(A, torch.Tensor): | 
 |         return A.layout == torch.sparse_coo | 
 |  | 
 |     error_str = "expected Tensor" | 
 |     if not torch.jit.is_scripting(): | 
 |         error_str += " but got {}".format(type(A)) | 
 |     raise TypeError(error_str) | 
 |  | 
 |  | 
 | def get_floating_dtype(A): | 
 |     """Return the floating point dtype of tensor A. | 
 |  | 
 |     Integer types map to float32. | 
 |     """ | 
 |     dtype = A.dtype | 
 |     if dtype in (torch.float16, torch.float32, torch.float64): | 
 |         return dtype | 
 |     return torch.float32 | 
 |  | 
 |  | 
 | def matmul(A: Optional[Tensor], B: Tensor) -> Tensor: | 
 |     """Multiply two matrices. | 
 |  | 
 |     If A is None, return B. A can be sparse or dense. B is always | 
 |     dense. | 
 |     """ | 
 |     if A is None: | 
 |         return B | 
 |     if is_sparse(A): | 
 |         return torch.sparse.mm(A, B) | 
 |     return torch.matmul(A, B) | 
 |  | 
 |  | 
 | def conjugate(A): | 
 |     """Return conjugate of tensor A. | 
 |  | 
 |     .. note:: If A's dtype is not complex, A is returned. | 
 |     """ | 
 |     if A.is_complex(): | 
 |         return A.conj() | 
 |     return A | 
 |  | 
 |  | 
 | def transpose(A): | 
 |     """Return transpose of a matrix or batches of matrices.""" | 
 |     ndim = len(A.shape) | 
 |     return A.transpose(ndim - 1, ndim - 2) | 
 |  | 
 |  | 
 | def transjugate(A): | 
 |     """Return transpose conjugate of a matrix or batches of matrices.""" | 
 |     return conjugate(transpose(A)) | 
 |  | 
 |  | 
 | def bform(X: Tensor, A: Optional[Tensor], Y: Tensor) -> Tensor: | 
 |     """Return bilinear form of matrices: :math:`X^T A Y`.""" | 
 |     return matmul(transpose(X), matmul(A, Y)) | 
 |  | 
 |  | 
 | def qform(A: Optional[Tensor], S: Tensor): | 
 |     """Return quadratic form :math:`S^T A S`.""" | 
 |     return bform(S, A, S) | 
 |  | 
 |  | 
 | def basis(A): | 
 |     """Return orthogonal basis of A columns.""" | 
 |     return torch.linalg.qr(A).Q | 
 |  | 
 |  | 
 | def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]: | 
 |     """Return eigenpairs of A with specified ordering.""" | 
 |     if largest is None: | 
 |         largest = False | 
 |     E, Z = torch.linalg.eigh(A, UPLO="U") | 
 |     # assuming that E is ordered | 
 |     if largest: | 
 |         E = torch.flip(E, dims=(-1,)) | 
 |         Z = torch.flip(Z, dims=(-1,)) | 
 |     return E, Z | 
 |  | 
 |  | 
 | # These functions were deprecated and removed | 
 | # This nice error message can be removed in version 1.13+ | 
 | def matrix_rank(input, tol=None, symmetric=False, *, out=None) -> Tensor: | 
 |     raise RuntimeError( | 
 |         "This function was deprecated since version 1.9 and is now removed.", | 
 |         "Please use the `torch.linalg.matrix_rank` function instead.", | 
 |     ) | 
 |  | 
 |  | 
 | def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]: | 
 |     raise RuntimeError( | 
 |         "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead.", | 
 |     ) | 
 |  | 
 |  | 
 | def lstsq(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]: | 
 |     raise RuntimeError( | 
 |         "This function was deprecated since version 1.9 and is now removed.", | 
 |         "Please use the `torch.linalg.lstsq` function instead.", | 
 |     ) | 
 |  | 
 |  | 
 | def _symeig( | 
 |     input, eigenvectors=False, upper=True, *, out=None | 
 | ) -> Tuple[Tensor, Tensor]: | 
 |     raise RuntimeError( | 
 |         "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.eigh` function instead.", | 
 |     ) | 
 |  | 
 |  | 
 | def eig( | 
 |     self: Tensor, eigenvectors: bool = False, *, e=None, v=None | 
 | ) -> Tuple[Tensor, Tensor]: | 
 |     raise RuntimeError( | 
 |         "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.eig` function instead.", | 
 |     ) |