[ADD] Implement Hutch++ trace estimation (#39)
* [ADD] Implement Hutch++

* [ADD] Test Hutch++

* [DOC] Specify comparability with Hutchinson

* [FIX] Bug: Use span of `A @ S`, not span of `S`
f-dangel authored Oct 16, 2023
1 parent 1b91227 commit 42913be
Showing 5 changed files with 207 additions and 1 deletion.
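In short, this commit adds a `HutchPPTraceEstimator` alongside the existing `HutchinsonTraceEstimator`. A rough usage sketch (not part of the diff; `aslinearoperator` and the concrete sizes are only illustrative, chosen to stress that the estimator needs nothing but matrix-vector products):

```python
from numpy import mean, trace
from numpy.random import rand, seed
from scipy.sparse.linalg import aslinearoperator

from curvlinops import HutchPPTraceEstimator

seed(0)
A = rand(100, 100)
A_linop = aslinearoperator(A)  # only matrix-vector products are required

estimator = HutchPPTraceEstimator(A_linop, basis_dim=10)
# the basis is built lazily on the first sample; each later sample costs one matvec
tr_estimate = mean([estimator.sample() for _ in range(500)])
print(trace(A), tr_estimate)
```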
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
@@ -14,6 +14,7 @@
)
from curvlinops.submatrix import SubmatrixLinearOperator
from curvlinops.trace.hutchinson import HutchinsonTraceEstimator
from curvlinops.trace.meyer2020hutch import HutchPPTraceEstimator

__all__ = [
"HessianLinearOperator",
@@ -30,4 +31,5 @@
"LanczosApproximateSpectrumCached",
"LanczosApproximateLogSpectrumCached",
"HutchinsonTraceEstimator",
"HutchPPTraceEstimator",
]
158 changes: 158 additions & 0 deletions curvlinops/trace/meyer2020hutch.py
@@ -0,0 +1,158 @@
"""Implementation of Hutch++ trace estimation from Meyer et al."""

from typing import Callable, Dict, Optional, Union

from numpy import column_stack, dot, ndarray
from numpy.linalg import qr
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher


class HutchPPTraceEstimator:
"""Class to perform trace estimation with the Huch++ method.
In contrast to vanilla Hutchinson, Hutch++ has lower variance, but requires more
memory.
For details, see
- Meyer, R. A., Musco, C., Musco, C., & Woodruff, D. P. (2020). Hutch++:
optimal stochastic trace estimation.
Example:
>>> from numpy import trace, mean, round
>>> from numpy.random import rand, seed
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> tr_A = trace(A) # exact trace as reference
>>> estimator = HutchPPTraceEstimator(A)
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = estimator.sample()
>>> tr_A_high_precision = mean([estimator.sample() for _ in range(998)])
>>> # assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(4.4575, 2.4085, 4.5791)
Attributes:
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

    SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
        "rademacher": rademacher,
        "normal": normal,
    }

    def __init__(
        self,
        A: LinearOperator,
        basis_dim: Optional[int] = None,
        basis_distribution: str = "rademacher",
    ):
"""Store the linear operator whose trace will be estimated.
Args:
A: Linear square-shaped operator whose trace will be estimated.
basis_dim: Dimension of the subspace used for exact trace estimation.
Can be at most the linear operator's dimension. By default, its
size will be 1% of the matrix's dimension, but at most ``10``.
This assumes that we are working with very large matrices and we can
only afford storing a small number of columns at a time.
basis_distribution: Distribution of the vectors used to construct the
subspace. Either ``'rademacher'` or ``'normal'``. Default is
``'rademacher'``.
Raises:
ValueError: If the operator is not square, the basis dimension is too
large, or the sampling distribution is not supported.
Note:
If you are planning to perform a fair (i.e. same computation budget)
comparison with vanilla Hutchinson, ``basis_dim`` should be ``s / 3``
where ``s`` is the number of samples used by vanilla Hutchinson. If
``s / 3`` requires storing a too large matrix, you can pick
``basis_dim = s1`` and draw ``s2`` samples from Hutch++ such that
``2 * s1 + s2 = s``.
"""
        if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
            raise ValueError(f"A must be square. Got shape {A.shape}.")
        self._A = A

        dim = A.shape[1]
        basis_dim = basis_dim if basis_dim is not None else min(10, max(dim // 100, 1))
        if basis_dim > self._A.shape[1]:
            raise ValueError(
                f"Basis dimension must be at most {self._A.shape[1]}. Got {basis_dim}."
            )
        self._basis_dim = basis_dim

        if basis_distribution not in self.SUPPORTED_DISTRIBUTIONS:
            raise ValueError(
                f"Unsupported distribution {basis_distribution!r}. "
                f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
            )
        self._basis_distribution = basis_distribution

        # When drawing the first sample, the basis and its subspace trace will be
        # computed and stored in the following buffers for further samples.
        self._Q: Union[ndarray, None] = None
        self._tr_QT_A_Q: Union[float, None] = None

    def sample(self, distribution: str = "rademacher") -> float:
        """Draw a sample from the trace estimator.

        Multiple samples can be combined into a more accurate trace estimation via
        averaging.

        Note:
            Calling this function for the first time will also compute the subspace
            and its trace. Future calls will be faster as the latter are cached
            internally.

        Args:
            distribution: Distribution of the vector along which the linear operator
                will be evaluated. Either ``'rademacher'`` or ``'normal'``.
                Default is ``'rademacher'``.

        Returns:
            Sample from the trace estimator.

        Raises:
            ValueError: If the distribution is not supported.
        """
        self.maybe_compute_and_cache_subspace()

        if distribution not in self.SUPPORTED_DISTRIBUTIONS:
            raise ValueError(
                f"Unsupported distribution {distribution!r}. "
                f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
            )

        dim = self._A.shape[1]
        v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
        # project out the subspace spanned by Q
        v -= self._Q @ (self._Q.T @ v)

        Av = self._A @ v

        return self._tr_QT_A_Q + dot(v, Av)

    def maybe_compute_and_cache_subspace(self):
        """Compute and cache the subspace and its trace if not already done."""
        if self._Q is not None and self._tr_QT_A_Q is not None:
            return

        dim = self._A.shape[1]
        AS = column_stack(
            [
                self._A @ self.SUPPORTED_DISTRIBUTIONS[self._basis_distribution](dim)
                for _ in range(self._basis_dim)
            ]
        )
        self._Q, _ = qr(AS)

        self._tr_QT_A_Q = 0.0
        for i in range(self._basis_dim):
            v = self._Q[:, i]
            Av = self._A @ v
            self._tr_QT_A_Q += dot(v, Av)
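The estimator above relies on an exact split of the trace: with `Q` an orthonormal basis of `span(A @ S)` (this is what the `[FIX]` bullet in the commit message refers to; the basis must come from `A @ S`, not `S`) and `P = I - Q Q^T`, one has `tr(A) = tr(Q^T A Q) + tr(P A P)`. Only the second, deflated term is estimated stochastically, which lowers the variance when `A` is close to low-rank. A small NumPy check of that identity (independent of the new class; names and sizes are only illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((10, 10))
S = rng.choice([-1.0, 1.0], size=(10, 3))  # Rademacher sketch vectors

Q, _ = np.linalg.qr(A @ S)   # orthonormal basis of span(A @ S), not span(S)
P = np.eye(10) - Q @ Q.T     # projector onto the orthogonal complement

# exact split: trace = deterministic low-rank part + trace of the deflated operator
lhs = np.trace(A)
rhs = np.trace(Q.T @ A @ Q) + np.trace(P @ A @ P)
print(np.isclose(lhs, rhs))  # True
```

Hutchinson sampling is then applied only to `P A P`, whose trace is small whenever the dominant directions of `A` are captured by `Q`.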
3 changes: 3 additions & 0 deletions docs/rtd/linops.rst
@@ -65,3 +65,6 @@ Trace approximation

.. autoclass:: curvlinops.HutchinsonTraceEstimator
   :members: __init__, sample

.. autoclass:: curvlinops.HutchPPTraceEstimator
   :members: __init__, sample
2 changes: 1 addition & 1 deletion test/trace/test_hutchinson.py
@@ -1,4 +1,4 @@
"""Test ``curvlinops.trace.__init__``."""
"""Test ``curvlinops.trace.hutchinson``."""

from numpy import isclose, mean, trace
from numpy.random import rand, seed
43 changes: 43 additions & 0 deletions test/trace/test_meyer2020hutch.py
@@ -0,0 +1,43 @@
"""Test ``curvlinops.trace.meyer2020hutch."""

from numpy import isclose, mean, trace
from numpy.random import rand, seed
from pytest import mark

from curvlinops import HutchPPTraceEstimator

DISTRIBUTIONS = ["rademacher", "normal"]
DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS]


@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_HutchPPTraceEstimator(distribution: str):
"""Test whether Hutch++'s trace estimator converges to the true trace.
Args:
distribution: Distribution of the random vectors used for the trace estimation.
"""
seed(0)
A = rand(10, 10)
tr_A = trace(A)

samples = []
max_samples = 20_000
chunk_size = 2_000 # add that many new samples before comparing against the truth
atol, rtol = 1e-3, 1e-2

estimator = HutchPPTraceEstimator(A)

while len(samples) < max_samples:
samples.extend(
[estimator.sample(distribution=distribution) for _ in range(chunk_size)]
)
tr_estimator = mean(samples)
if not isclose(tr_A, tr_estimator, atol=atol, rtol=rtol):
print(f"{len(samples)} samples: Tr(A)={tr_A:.5f}{tr_estimator:.5f}.")
else:
# quit once the estimator has converged
break

tr_estimator = mean(samples)
assert isclose(tr_A, tr_estimator, atol=atol, rtol=rtol)
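The new docstring's Note describes how to budget-match Hutch++ against vanilla Hutchinson: building the basis costs roughly `2 * basis_dim` matrix-vector products up front, and each sample costs one more, so `basis_dim = s1` with `s2` samples and `2 * s1 + s2 = s` matches Hutchinson's `s` products. A hedged sketch of such a comparison, assuming `HutchinsonTraceEstimator` exposes the no-argument `sample()` shown in its documentation entry above; the concrete matrix and budget are only illustrative, and on a generic dense random matrix neither estimator is expected to dominate:

```python
from numpy import mean, trace
from numpy.random import rand, seed

from curvlinops import HutchinsonTraceEstimator, HutchPPTraceEstimator

seed(0)
A = rand(50, 50)
s = 90                      # total matrix-vector budget
s1, s2 = 30, 30             # 2 * s1 + s2 = s

hutchinson = HutchinsonTraceEstimator(A)
tr_hutchinson = mean([hutchinson.sample() for _ in range(s)])

hutchpp = HutchPPTraceEstimator(A, basis_dim=s1)
tr_hutchpp = mean([hutchpp.sample() for _ in range(s2)])

print(trace(A), tr_hutchinson, tr_hutchpp)
```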
