Skip to content

Commit

Permalink
[ADD] Vanilla Hutchinson trace estimation (#38)
Browse files Browse the repository at this point in the history
* [ADD] Hutchinson trace estimator

* [DOC] Add Hutchinson estimator to documentation

* [REF] Move example, polish rst

* [RTD] Try fixing deprecated `build.image`

https://blog.readthedocs.com/use-build-os-config/#use-build-os-instead-of-build-image-on-your-configuration-file

* [FIX] Precision of doctest
  • Loading branch information
f-dangel authored Oct 16, 2023
1 parent fb6ba25 commit 1b91227
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 7 deletions.
14 changes: 9 additions & 5 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ version: 2
sphinx:
configuration: docs/rtd/conf.py

build:
os: ubuntu-22.04
tools:
python: "3.8"

python:
version: "3.8"
install:
- method: pip
path: .
extra_requirements:
- docs
- method: pip
path: .
extra_requirements:
- docs
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
lanczos_approximate_spectrum,
)
from curvlinops.submatrix import SubmatrixLinearOperator
from curvlinops.trace.hutchinson import HutchinsonTraceEstimator

__all__ = [
"HessianLinearOperator",
Expand All @@ -28,4 +29,5 @@
"lanczos_approximate_log_spectrum",
"LanczosApproximateSpectrumCached",
"LanczosApproximateLogSpectrumCached",
"HutchinsonTraceEstimator",
]
1 change: 1 addition & 0 deletions curvlinops/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Trace estimation techniques."""
85 changes: 85 additions & 0 deletions curvlinops/trace/hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Vanilla Hutchinson trace estimation."""

from typing import Callable, Dict

from numpy import dot, ndarray
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher


class HutchinsonTraceEstimator:
"""Class to perform trace estimation with Hutchinson's method.
For details, see
- Hutchinson, M. (1989). A stochastic estimator of the trace of the influence
matrix for laplacian smoothing splines. Communication in Statistics---Simulation
and Computation.
Example:
>>> from numpy import trace, mean, round
>>> from numpy.random import rand, seed
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> tr_A = trace(A) # exact trace as reference
>>> estimator = HutchinsonTraceEstimator(A)
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = estimator.sample()
>>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)])
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(4.4575, 6.6796, 4.3886)
Attributes:
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}

def __init__(self, A: LinearOperator):
"""Store the linear operator whose trace will be estimated.
Args:
A: Linear square-shaped operator whose trace will be estimated.
Raises:
ValueError: If the operator is not square.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
self._A = A

def sample(self, distribution: str = "rademacher") -> float:
"""Draw a sample from the trace estimator.
Multiple samples can be combined into a more accurate trace estimation via
averaging.
Args:
distribution: Distribution of the vector along which the linear operator
will be evaluated. Either ``'rademacher'`` or ``'normal'``.
Default is ``'rademacher'``.
Returns:
Sample from the trace estimator.
Raises:
ValueError: If the distribution is not supported.
"""
dim = self._A.shape[1]

if distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)

v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
Av = self._A @ v

return dot(v, Av)
29 changes: 29 additions & 0 deletions curvlinops/trace/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Sampling methods for random vectors."""

from numpy import ndarray
from numpy.random import binomial, randn


def rademacher(dim: int) -> ndarray:
"""Draw a vector with i.i.d. Rademacher elements.
Args:
dim: Dimension of the vector.
Returns:
Vector with i.i.d. Rademacher elements and specified dimension.
"""
num_trials, success_prob = 1, 0.5
return binomial(num_trials, success_prob, size=dim).astype(float) * 2 - 1


def normal(dim: int) -> ndarray:
"""Drawa vector with i.i.d. standard normal elements.
Args:
dim: Dimension of the vector.
Returns:
Vector with i.i.d. standard normal elements and specified dimension.
"""
return randn(dim)
6 changes: 6 additions & 0 deletions docs/rtd/linops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,9 @@ Spectral density approximation

.. autoclass:: curvlinops.LanczosApproximateSpectrumCached
:members: __init__, approximate_spectrum

Trace approximation
===================

.. autoclass:: curvlinops.HutchinsonTraceEstimator
:members: __init__, sample
4 changes: 2 additions & 2 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ install-test:
.PHONY: test test-light

test:
@pytest -vx --run-optional-tests=montecarlo --cov=curvlinops test
@pytest -vx --run-optional-tests=montecarlo --cov=curvlinops --doctest-modules curvlinops test

test-light:
@pytest -vx --cov=curvlinops test
@pytest -vx --cov=curvlinops --doctest-modules curvlinops test

.PHONY: install-lint

Expand Down
1 change: 1 addition & 0 deletions test/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test ``curvlinops.trace``."""
43 changes: 43 additions & 0 deletions test/trace/test_hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test ``curvlinops.trace.__init__``."""

from numpy import isclose, mean, trace
from numpy.random import rand, seed
from pytest import mark

from curvlinops import HutchinsonTraceEstimator

DISTRIBUTIONS = ["rademacher", "normal"]
DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS]


@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_HutchinsonTraceEstimator(distribution: str):
"""Test whether Hutchinon's trace estimator converges to the true trace.
Args:
distribution: Distribution of the random vectors used for the trace estimation.
"""
seed(0)
A = rand(10, 10)
tr_A = trace(A)

samples = []
max_samples = 20_000
chunk_size = 2_000 # add that many new samples before comparing against the truth
atol, rtol = 1e-3, 1e-2

estimator = HutchinsonTraceEstimator(A)

while len(samples) < max_samples:
samples.extend(
[estimator.sample(distribution=distribution) for _ in range(chunk_size)]
)
tr_estimator = mean(samples)
if not isclose(tr_A, tr_estimator, atol=atol, rtol=rtol):
print(f"{len(samples)} samples: Tr(A)={tr_A:.5f}{tr_estimator:.5f}.")
else:
# quit once the estimator has converged
break

tr_estimator = mean(samples)
assert isclose(tr_A, tr_estimator, atol=atol, rtol=rtol)

0 comments on commit 1b91227

Please sign in to comment.