Skip to content

Commit

Permalink
[ADD] Hutchinson trace estimator
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Oct 16, 2023
1 parent fb6ba25 commit 0cb608a
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 2 deletions.
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
lanczos_approximate_spectrum,
)
from curvlinops.submatrix import SubmatrixLinearOperator
from curvlinops.trace.hutchinson import HutchinsonTraceEstimator

__all__ = [
"HessianLinearOperator",
Expand All @@ -28,4 +29,5 @@
"lanczos_approximate_log_spectrum",
"LanczosApproximateSpectrumCached",
"LanczosApproximateLogSpectrumCached",
"HutchinsonTraceEstimator",
]
1 change: 1 addition & 0 deletions curvlinops/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Trace estimation techniques."""
85 changes: 85 additions & 0 deletions curvlinops/trace/hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Vanilla Hutchinson trace estimation."""

from typing import Callable, Dict

from numpy import dot, ndarray
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher


class HutchinsonTraceEstimator:
"""Class to perform trace estimation with Hutchinson's method.
For details, see
- Hutchinson, M. (1989). A stochastic estimator of the trace of the influence
matrix for laplacian smoothing splines. Communication in Statistics---Simulation
and Computation.
Attributes:
SUPPORTED_SAMPLINGS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_SAMPLINGS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}

def __init__(self, A: LinearOperator):
"""Store the linear operator whose trace will be estimated.
Args:
A: Linear square-shaped operator whose trace will be estimated.
Raises:
ValueError: If the operator is not square.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
self._A = A

def sample(self, distribution: str = "rademacher") -> float:
"""Draw a sample from the trace estimator.
Multiple samples can be combined into a more accurate trace estimation via
averaging.
Args:
distribution: Distribution of the vector along which the linear operator
will be evaluated. Either `'rademacher'` or `'normal'`.
Default is `'rademacher'`.
Returns:
Sample from the trace estimator.
Raises:
ValueError: If the distribution is not supported.
Example:
>>> from numpy import trace, mean
>>> from numpy.random import rand, seed
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> tr_A = trace(A) # exact trace as reference
>>> estimator = HutchinsonTraceEstimator(A)
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = estimator.sample()
>>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)])
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> tr_A, tr_A_low_precision, tr_A_high_precision
(4.457529730942303, 6.679568384120655, 4.388630875995861)
"""
dim = self._A.shape[1]

if distribution not in self.SUPPORTED_SAMPLINGS:
raise ValueError(
f"Unsupported distribution '{distribution}'. "
f"Supported distributions are {list(self.SUPPORTED_SAMPLINGS)}."
)

v = self.SUPPORTED_SAMPLINGS[distribution](dim)
Av = self._A @ v

return dot(v, Av)
29 changes: 29 additions & 0 deletions curvlinops/trace/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Sampling methods for random vectors."""

from numpy import ndarray
from numpy.random import binomial, randn


def rademacher(dim: int) -> ndarray:
"""Draw a vector with i.i.d. Rademacher elements.
Args:
dim: Dimension of the vector.
Returns:
Vector with i.i.d. Rademacher elements and specified dimension.
"""
num_trials, success_prob = 1, 0.5
return binomial(num_trials, success_prob, size=dim).astype(float) * 2 - 1


def normal(dim: int) -> ndarray:
"""Drawa vector with i.i.d. standard normal elements.
Args:
dim: Dimension of the vector.
Returns:
Vector with i.i.d. standard normal elements and specified dimension.
"""
return randn(dim)
4 changes: 2 additions & 2 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ install-test:
.PHONY: test test-light

test:
@pytest -vx --run-optional-tests=montecarlo --cov=curvlinops test
@pytest -vx --run-optional-tests=montecarlo --cov=curvlinops --doctest-modules curvlinops test

test-light:
@pytest -vx --cov=curvlinops test
@pytest -vx --cov=curvlinops --doctest-modules curvlinops test

.PHONY: install-lint

Expand Down
1 change: 1 addition & 0 deletions test/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test ``curvlinops.trace``."""
43 changes: 43 additions & 0 deletions test/trace/test__hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test ``curvlinops.trace.__init__``."""

from numpy import isclose, mean, trace
from numpy.random import rand, seed
from pytest import mark

from curvlinops import HutchinsonTraceEstimator

DISTRIBUTIONS = ["rademacher", "normal"]
DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS]


@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_HutchinsonTraceEstimator(distribution: str):
"""Test whether Hutchinon's trace estimator converges to the true trace.
Args:
distribution: Distribution of the random vectors used for the trace estimation.
"""
seed(0)
A = rand(10, 10)
tr_A = trace(A)

samples = []
max_samples = 20_000
chunk_size = 2_000 # add that many new samples before comparing against the truth
atol, rtol = 1e-3, 1e-2

estimator = HutchinsonTraceEstimator(A)

while len(samples) < max_samples:
samples.extend(
[estimator.sample(distribution=distribution) for _ in range(chunk_size)]
)
tr_estimator = mean(samples)
if not isclose(tr_A, tr_estimator, atol=atol, rtol=rtol):
print(f"{len(samples)} samples: Tr(A)={tr_A:.5f}{tr_estimator:.5f}.")
else:
# quit once the estimator has converged
break

tr_estimator = mean(samples)
assert isclose(tr_A, tr_estimator, atol=atol, rtol=rtol)

0 comments on commit 0cb608a

Please sign in to comment.