Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] Vanilla Hutchinson trace estimation #38

Merged
merged 5 commits into from
Oct 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@ version: 2
sphinx:
configuration: docs/rtd/conf.py

build:
os: ubuntu-22.04
tools:
python: "3.8"

python:
version: "3.8"
install:
- method: pip
path: .
extra_requirements:
- docs
- method: pip
path: .
extra_requirements:
- docs
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
lanczos_approximate_spectrum,
)
from curvlinops.submatrix import SubmatrixLinearOperator
from curvlinops.trace.hutchinson import HutchinsonTraceEstimator

__all__ = [
"HessianLinearOperator",
Expand All @@ -28,4 +29,5 @@
"lanczos_approximate_log_spectrum",
"LanczosApproximateSpectrumCached",
"LanczosApproximateLogSpectrumCached",
"HutchinsonTraceEstimator",
]
1 change: 1 addition & 0 deletions curvlinops/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Trace estimation techniques."""
85 changes: 85 additions & 0 deletions curvlinops/trace/hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Vanilla Hutchinson trace estimation."""

from typing import Callable, Dict

from numpy import dot, ndarray
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher


class HutchinsonTraceEstimator:
"""Class to perform trace estimation with Hutchinson's method.

For details, see

- Hutchinson, M. (1989). A stochastic estimator of the trace of the influence
matrix for laplacian smoothing splines. Communication in Statistics---Simulation
and Computation.

Example:
>>> from numpy import trace, mean, round
>>> from numpy.random import rand, seed
>>> seed(0) # make deterministic
>>> A = rand(10, 10)
>>> tr_A = trace(A) # exact trace as reference
>>> estimator = HutchinsonTraceEstimator(A)
>>> # one- and multi-sample approximations
>>> tr_A_low_precision = estimator.sample()
>>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)])
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(4.4575, 6.6796, 4.3886)

Attributes:
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}

def __init__(self, A: LinearOperator):
"""Store the linear operator whose trace will be estimated.

Args:
A: Linear square-shaped operator whose trace will be estimated.

Raises:
ValueError: If the operator is not square.
"""
if len(A.shape) != 2 or A.shape[0] != A.shape[1]:
raise ValueError(f"A must be square. Got shape {A.shape}.")
self._A = A

def sample(self, distribution: str = "rademacher") -> float:
"""Draw a sample from the trace estimator.

Multiple samples can be combined into a more accurate trace estimation via
averaging.

Args:
distribution: Distribution of the vector along which the linear operator
will be evaluated. Either ``'rademacher'`` or ``'normal'``.
Default is ``'rademacher'``.

Returns:
Sample from the trace estimator.

Raises:
ValueError: If the distribution is not supported.
"""
dim = self._A.shape[1]

if distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)

v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
Av = self._A @ v

return dot(v, Av)
29 changes: 29 additions & 0 deletions curvlinops/trace/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Sampling methods for random vectors."""

from numpy import ndarray
from numpy.random import binomial, randn


def rademacher(dim: int) -> ndarray:
"""Draw a vector with i.i.d. Rademacher elements.

Args:
dim: Dimension of the vector.

Returns:
Vector with i.i.d. Rademacher elements and specified dimension.
"""
num_trials, success_prob = 1, 0.5
return binomial(num_trials, success_prob, size=dim).astype(float) * 2 - 1


def normal(dim: int) -> ndarray:
"""Drawa vector with i.i.d. standard normal elements.

Args:
dim: Dimension of the vector.

Returns:
Vector with i.i.d. standard normal elements and specified dimension.
"""
return randn(dim)
6 changes: 6 additions & 0 deletions docs/rtd/linops.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,9 @@ Spectral density approximation

.. autoclass:: curvlinops.LanczosApproximateSpectrumCached
:members: __init__, approximate_spectrum

Trace approximation
===================

.. autoclass:: curvlinops.HutchinsonTraceEstimator
:members: __init__, sample
4 changes: 2 additions & 2 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,10 @@ install-test:
.PHONY: test test-light

test:
@pytest -vx --run-optional-tests=montecarlo --cov=curvlinops test
@pytest -vx --run-optional-tests=montecarlo --cov=curvlinops --doctest-modules curvlinops test

test-light:
@pytest -vx --cov=curvlinops test
@pytest -vx --cov=curvlinops --doctest-modules curvlinops test

.PHONY: install-lint

Expand Down
1 change: 1 addition & 0 deletions test/trace/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Test ``curvlinops.trace``."""
43 changes: 43 additions & 0 deletions test/trace/test_hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Test ``curvlinops.trace.__init__``."""

from numpy import isclose, mean, trace
from numpy.random import rand, seed
from pytest import mark

from curvlinops import HutchinsonTraceEstimator

DISTRIBUTIONS = ["rademacher", "normal"]
DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS]


@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS)
def test_HutchinsonTraceEstimator(distribution: str):
"""Test whether Hutchinon's trace estimator converges to the true trace.

Args:
distribution: Distribution of the random vectors used for the trace estimation.
"""
seed(0)
A = rand(10, 10)
tr_A = trace(A)

samples = []
max_samples = 20_000
chunk_size = 2_000 # add that many new samples before comparing against the truth
atol, rtol = 1e-3, 1e-2

estimator = HutchinsonTraceEstimator(A)

while len(samples) < max_samples:
samples.extend(
[estimator.sample(distribution=distribution) for _ in range(chunk_size)]
)
tr_estimator = mean(samples)
if not isclose(tr_A, tr_estimator, atol=atol, rtol=rtol):
print(f"{len(samples)} samples: Tr(A)={tr_A:.5f}≠{tr_estimator:.5f}.")
else:
# quit once the estimator has converged
break

tr_estimator = mean(samples)
assert isclose(tr_A, tr_estimator, atol=atol, rtol=rtol)
Loading