|
"""Implement various linear algebra algorithms for low rank matrices.""" |
|
|
|
__all__ = ["svd_lowrank", "pca_lowrank"] |
|
|
|
from typing import Optional |
|
|
|
import torch |
|
from torch import _linalg_utils as _utils, Tensor |
|
from torch.overrides import handle_torch_function, has_torch_function |
|
|
|
|
|
def get_approximate_basis(
    A: Tensor,
    q: int,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> Tensor:
    """Return a tensor :math:`Q` with :math:`q` orthonormal columns such
    that :math:`Q Q^H A` approximates :math:`A`.  If :math:`M` is
    specified, :math:`Q` is computed such that :math:`Q Q^H (A - M)`
    approximates :math:`A - M`, without instantiating any tensors of the
    size of :math:`A` or :math:`M`.

    .. note:: The implementation is based on the Algorithm 4.4 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: The basis is seeded with a random Gaussian matrix; to
              obtain repeatable results, reset the seed for the
              pseudorandom number generator.

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int): the dimension of subspace spanned by :math:`Q`
                 columns.

        niter (int, optional): the number of subspace iterations to
                               conduct; ``niter`` must be a
                               nonnegative integer. In most cases, the
                               default value 2 is more than enough.

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).
    """
    # None means "use the default number of power iterations".
    n_iter = 2 if niter is None else niter
    # Complex inputs keep their dtype; otherwise promote to a floating dtype.
    dtype = A.dtype if A.is_complex() else _utils.get_floating_dtype(A)
    mm = _utils.matmul

    # Random Gaussian test matrix seeding the range finder (Halko Alg. 4.4).
    R = torch.randn(A.shape[-1], q, dtype=dtype, device=A.device)

    # Initial sketch of the range of (A - M); M is subtracted term-wise so
    # that A - M is never materialized.
    Y = mm(A, R)
    if M is not None:
        Y = Y - mm(M, R)
    basis = torch.linalg.qr(Y).Q

    # Power (subspace) iterations: alternate multiplication by (A - M)^H and
    # (A - M), re-orthonormalizing after every product for stability.
    for _ in range(n_iter):
        Y = mm(A.mH, basis)
        if M is not None:
            Y = Y - mm(M.mH, basis)
        basis = torch.linalg.qr(Y).Q
        Y = mm(A, basis)
        if M is not None:
            Y = Y - mm(M, basis)
        basis = torch.linalg.qr(Y).Q
    return basis
|
|
|
|
|
def svd_lowrank(
    A: Tensor,
    q: Optional[int] = 6,
    niter: Optional[int] = 2,
    M: Optional[Tensor] = None,
) -> tuple[Tensor, Tensor, Tensor]:
    r"""Return the singular value decomposition ``(U, S, V)`` of a matrix,
    batches of matrices, or a sparse matrix :math:`A` such that
    :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`. In case :math:`M` is given, then
    SVD is computed for the matrix :math:`A - M`.

    .. note:: The implementation is based on the Algorithm 5.1 from
              Halko et al., 2009.

    .. note:: For an adequate approximation of a k-rank matrix
              :math:`A`, where k is not known in advance but could be
              estimated, the number of :math:`Q` columns, q, can be
              chosen according to the following criteria: in general,
              :math:`k <= q <= min(2*k, m, n)`. For large low-rank
              matrices, take :math:`q = k + 5..10`.  If k is
              relatively small compared to :math:`min(m, n)`, choosing
              :math:`q = k + 0..2` may be sufficient.

    .. note:: This is a randomized method. To obtain repeatable results,
              set the seed for the pseudorandom number generator

    .. note:: In general, use the full-rank SVD implementation
              :func:`torch.linalg.svd` for dense matrices due to its 10x
              higher performance characteristics. The low-rank SVD
              will be useful for huge sparse matrices that
              :func:`torch.linalg.svd` cannot handle.

    Args::
        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of A.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2

        M (Tensor, optional): the input tensor's mean of size
                              :math:`(*, m, n)`, which will be broadcasted
                              to the size of A in this function.

    References::
        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <https://arxiv.org/abs/0909.4061>`_).
    """
    # Dispatch to __torch_function__ overrides (e.g. tensor subclasses)
    # unless we are inside TorchScript, where overrides do not apply.
    if not torch.jit.is_scripting():
        relevant_args = (A, M)
        all_plain = set(map(type, relevant_args)).issubset((torch.Tensor, type(None)))
        if not all_plain and has_torch_function(relevant_args):
            return handle_torch_function(
                svd_lowrank, relevant_args, A, q=q, niter=niter, M=M
            )
    # Plain tensors: run the actual randomized SVD.
    return _svd_lowrank(A, q=q, niter=niter, M=M)
|
|
|
|
|
def _svd_lowrank( |
|
A: Tensor, |
|
q: Optional[int] = 6, |
|
niter: Optional[int] = 2, |
|
M: Optional[Tensor] = None, |
|
) -> tuple[Tensor, Tensor, Tensor]: |
|
|
|
|
|
q = 6 if q is None else q |
|
m, n = A.shape[-2:] |
|
matmul = _utils.matmul |
|
if M is not None: |
|
M = M.broadcast_to(A.size()) |
|
|
|
|
|
if m < n: |
|
A = A.mH |
|
if M is not None: |
|
M = M.mH |
|
|
|
Q = get_approximate_basis(A, q, niter=niter, M=M) |
|
B = matmul(Q.mH, A) |
|
if M is not None: |
|
B = B - matmul(Q.mH, M) |
|
U, S, Vh = torch.linalg.svd(B, full_matrices=False) |
|
V = Vh.mH |
|
U = Q.matmul(U) |
|
|
|
if m < n: |
|
U, V = V, U |
|
|
|
return U, S, V |
|
|
|
|
|
def pca_lowrank(
    A: Tensor,
    q: Optional[int] = None,
    center: bool = True,
    niter: int = 2,
) -> tuple[Tensor, Tensor, Tensor]:
    r"""Performs linear Principal Component Analysis (PCA) on a low-rank
    matrix, batches of such matrices, or sparse matrix.

    This function returns a namedtuple ``(U, S, V)`` which is the
    nearly optimal approximation of a singular value decomposition of
    a centered matrix :math:`A` such that :math:`A \approx U \operatorname{diag}(S) V^{\text{H}}`

    .. note:: The relation of ``(U, S, V)`` to PCA is as follows:

                - :math:`A` is a data matrix with ``m`` samples and
                  ``n`` features

                - the :math:`V` columns represent the principal directions

                - :math:`S ** 2 / (m - 1)` contains the eigenvalues of
                  :math:`A^T A / (m - 1)` which is the covariance of
                  ``A`` when ``center=True`` is provided.

                - ``matmul(A, V[:, :k])`` projects data to the first k
                  principal components

    .. note:: Different from the standard SVD, the size of returned
              matrices depend on the specified rank and q
              values as follows:

                - :math:`U` is m x q matrix

                - :math:`S` is q-vector

                - :math:`V` is n x q matrix

    .. note:: To obtain repeatable results, reset the seed for the
              pseudorandom number generator

    Args:

        A (Tensor): the input tensor of size :math:`(*, m, n)`

        q (int, optional): a slightly overestimated rank of
                           :math:`A`. By default, ``q = min(6, m,
                           n)``.

        center (bool, optional): if True, center the input tensor,
                                 otherwise, assume that the input is
                                 centered.

        niter (int, optional): the number of subspace iterations to
                               conduct; niter must be a nonnegative
                               integer, and defaults to 2.

    References::

        - Nathan Halko, Per-Gunnar Martinsson, and Joel Tropp, Finding
          structure with randomness: probabilistic algorithms for
          constructing approximate matrix decompositions,
          arXiv:0909.4061 [math.NA; math.PR], 2009 (available at
          `arXiv <http://arxiv.org/abs/0909.4061>`_).

    """
    # Dispatch to __torch_function__ overrides (tensor subclasses) unless
    # running under TorchScript, where overrides do not apply.
    if not torch.jit.is_scripting():
        if type(A) is not torch.Tensor and has_torch_function((A,)):
            return handle_torch_function(
                pca_lowrank, (A,), A, q=q, center=center, niter=niter
            )

    (m, n) = A.shape[-2:]

    # Validate rank and iteration-count arguments.
    if q is None:
        q = min(6, m, n)
    elif not (q >= 0 and q <= min(m, n)):
        raise ValueError(
            f"q(={q}) must be non-negative integer and not greater than min(m, n)={min(m, n)}"
        )
    if not (niter >= 0):
        raise ValueError(f"niter(={niter}) must be non-negative integer")

    dtype = _utils.get_floating_dtype(A)

    # No centering requested: run the randomized SVD on A as-is.
    if not center:
        return _svd_lowrank(A, q, niter=niter, M=None)

    if _utils.is_sparse(A):
        # Sparse path: batches are not supported for sparse inputs.
        if len(A.shape) != 2:
            raise ValueError("pca_lowrank input is expected to be 2-dimensional tensor")
        # Column means as a sparse vector: sum over rows, divided by m.
        c = torch.sparse.sum(A, dim=(-2,)) / m

        # Rebuild the 1-D sparse mean `c` as an (n, 1) sparse COO column
        # vector C_t: row indices come from c, the (single) column index
        # stays zero.
        column_indices = c.indices()[0]
        indices = torch.zeros(
            2,
            len(column_indices),
            dtype=column_indices.dtype,
            device=column_indices.device,
        )
        indices[0] = column_indices
        C_t = torch.sparse_coo_tensor(
            indices, c.values(), (n, 1), dtype=dtype, device=A.device
        )

        # Outer product C_t @ ones(1, m), transposed, yields the (m, n)
        # mean matrix M with the column means replicated down each column,
        # so that _svd_lowrank effectively decomposes A - M without
        # densifying A itself.
        ones_m1_t = torch.ones(A.shape[:-2] + (1, m), dtype=dtype, device=A.device)
        M = torch.sparse.mm(C_t, ones_m1_t).mT
        return _svd_lowrank(A, q, niter=niter, M=M)
    else:
        # Dense path: subtract the per-column mean directly (broadcasts
        # over batch dimensions).
        C = A.mean(dim=(-2,), keepdim=True)
        return _svd_lowrank(A - C, q, niter=niter, M=None)
|
|