From 1b912270fff1b94608bf9b23c942385112387dc1 Mon Sep 17 00:00:00 2001 From: Felix Dangel <48687646+f-dangel@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:31:05 -0400 Subject: [PATCH] [ADD] Vanilla Hutchinson trace estimation (#38) * [ADD] Hutchinson trace estimator * [DOC] Add Hutchinson estimator to documentation * [REF] Move example, polish rst * [RTD] Try fixing deprecated `build.image` https://blog.readthedocs.com/use-build-os-config/#use-build-os-instead-of-build-image-on-your-configuration-file * [FIX] Precision of doctest --- .readthedocs.yaml | 14 ++++-- curvlinops/__init__.py | 2 + curvlinops/trace/__init__.py | 1 + curvlinops/trace/hutchinson.py | 85 ++++++++++++++++++++++++++++++++++ curvlinops/trace/sampling.py | 29 ++++++++++++ docs/rtd/linops.rst | 6 +++ makefile | 4 +- test/trace/__init__.py | 1 + test/trace/test_hutchinson.py | 43 +++++++++++++++++ 9 files changed, 178 insertions(+), 7 deletions(-) create mode 100644 curvlinops/trace/__init__.py create mode 100644 curvlinops/trace/hutchinson.py create mode 100644 curvlinops/trace/sampling.py create mode 100644 test/trace/__init__.py create mode 100644 test/trace/test_hutchinson.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 692e1a5..3d33c3e 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,10 +6,14 @@ version: 2 sphinx: configuration: docs/rtd/conf.py +build: + os: ubuntu-22.04 + tools: + python: "3.8" + python: - version: "3.8" install: - - method: pip - path: . - extra_requirements: - - docs + - method: pip + path: . + extra_requirements: + - docs diff --git a/curvlinops/__init__.py b/curvlinops/__init__.py index 37bf56d..f61b8cc 100644 --- a/curvlinops/__init__.py +++ b/curvlinops/__init__.py @@ -13,6 +13,7 @@ lanczos_approximate_spectrum, ) from curvlinops.submatrix import SubmatrixLinearOperator +from curvlinops.trace.hutchinson import HutchinsonTraceEstimator __all__ = [ "HessianLinearOperator", @@ -28,4 +29,5 @@ "lanczos_approximate_log_spectrum", "LanczosApproximateSpectrumCached", "LanczosApproximateLogSpectrumCached", + "HutchinsonTraceEstimator", ] diff --git a/curvlinops/trace/__init__.py b/curvlinops/trace/__init__.py new file mode 100644 index 0000000..81c8129 --- /dev/null +++ b/curvlinops/trace/__init__.py @@ -0,0 +1 @@ +"""Trace estimation techniques.""" diff --git a/curvlinops/trace/hutchinson.py b/curvlinops/trace/hutchinson.py new file mode 100644 index 0000000..79f013c --- /dev/null +++ b/curvlinops/trace/hutchinson.py @@ -0,0 +1,85 @@ +"""Vanilla Hutchinson trace estimation.""" + +from typing import Callable, Dict + +from numpy import dot, ndarray +from scipy.sparse.linalg import LinearOperator + +from curvlinops.trace.sampling import normal, rademacher + + +class HutchinsonTraceEstimator: + """Class to perform trace estimation with Hutchinson's method. + + For details, see + + - Hutchinson, M. (1989). A stochastic estimator of the trace of the influence + matrix for laplacian smoothing splines. Communication in Statistics---Simulation + and Computation. + + Example: + >>> from numpy import trace, mean, round + >>> from numpy.random import rand, seed + >>> seed(0) # make deterministic + >>> A = rand(10, 10) + >>> tr_A = trace(A) # exact trace as reference + >>> estimator = HutchinsonTraceEstimator(A) + >>> # one- and multi-sample approximations + >>> tr_A_low_precision = estimator.sample() + >>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)]) + >>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision) + >>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4) + (4.4575, 6.6796, 4.3886) + + Attributes: + SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their + sampling functions. + """ + + SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = { + "rademacher": rademacher, + "normal": normal, + } + + def __init__(self, A: LinearOperator): + """Store the linear operator whose trace will be estimated. + + Args: + A: Linear square-shaped operator whose trace will be estimated. + + Raises: + ValueError: If the operator is not square. + """ + if len(A.shape) != 2 or A.shape[0] != A.shape[1]: + raise ValueError(f"A must be square. Got shape {A.shape}.") + self._A = A + + def sample(self, distribution: str = "rademacher") -> float: + """Draw a sample from the trace estimator. + + Multiple samples can be combined into a more accurate trace estimation via + averaging. + + Args: + distribution: Distribution of the vector along which the linear operator + will be evaluated. Either ``'rademacher'`` or ``'normal'``. + Default is ``'rademacher'``. + + Returns: + Sample from the trace estimator. + + Raises: + ValueError: If the distribution is not supported. + """ + dim = self._A.shape[1] + + if distribution not in self.SUPPORTED_DISTRIBUTIONS: + raise ValueError( + f"Unsupported distribution {distribution:!r}. " + f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}." + ) + + v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim) + Av = self._A @ v + + return dot(v, Av) diff --git a/curvlinops/trace/sampling.py b/curvlinops/trace/sampling.py new file mode 100644 index 0000000..93ffbde --- /dev/null +++ b/curvlinops/trace/sampling.py @@ -0,0 +1,29 @@ +"""Sampling methods for random vectors.""" + +from numpy import ndarray +from numpy.random import binomial, randn + + +def rademacher(dim: int) -> ndarray: + """Draw a vector with i.i.d. Rademacher elements. + + Args: + dim: Dimension of the vector. + + Returns: + Vector with i.i.d. Rademacher elements and specified dimension. + """ + num_trials, success_prob = 1, 0.5 + return binomial(num_trials, success_prob, size=dim).astype(float) * 2 - 1 + + +def normal(dim: int) -> ndarray: + """Drawa vector with i.i.d. standard normal elements. + + Args: + dim: Dimension of the vector. + + Returns: + Vector with i.i.d. standard normal elements and specified dimension. + """ + return randn(dim) diff --git a/docs/rtd/linops.rst b/docs/rtd/linops.rst index 9b4a0b2..36cf046 100644 --- a/docs/rtd/linops.rst +++ b/docs/rtd/linops.rst @@ -59,3 +59,9 @@ Spectral density approximation .. autoclass:: curvlinops.LanczosApproximateSpectrumCached :members: __init__, approximate_spectrum + +Trace approximation +=================== + +.. autoclass:: curvlinops.HutchinsonTraceEstimator + :members: __init__, sample diff --git a/makefile b/makefile index cf2a7be..afb5c68 100644 --- a/makefile +++ b/makefile @@ -69,10 +69,10 @@ install-test: .PHONY: test test-light test: - @pytest -vx --run-optional-tests=montecarlo --cov=curvlinops test + @pytest -vx --run-optional-tests=montecarlo --cov=curvlinops --doctest-modules curvlinops test test-light: - @pytest -vx --cov=curvlinops test + @pytest -vx --cov=curvlinops --doctest-modules curvlinops test .PHONY: install-lint diff --git a/test/trace/__init__.py b/test/trace/__init__.py new file mode 100644 index 0000000..e4ff235 --- /dev/null +++ b/test/trace/__init__.py @@ -0,0 +1 @@ +"""Test ``curvlinops.trace``.""" diff --git a/test/trace/test_hutchinson.py b/test/trace/test_hutchinson.py new file mode 100644 index 0000000..6ee8bb6 --- /dev/null +++ b/test/trace/test_hutchinson.py @@ -0,0 +1,43 @@ +"""Test ``curvlinops.trace.__init__``.""" + +from numpy import isclose, mean, trace +from numpy.random import rand, seed +from pytest import mark + +from curvlinops import HutchinsonTraceEstimator + +DISTRIBUTIONS = ["rademacher", "normal"] +DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS] + + +@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS) +def test_HutchinsonTraceEstimator(distribution: str): + """Test whether Hutchinon's trace estimator converges to the true trace. + + Args: + distribution: Distribution of the random vectors used for the trace estimation. + """ + seed(0) + A = rand(10, 10) + tr_A = trace(A) + + samples = [] + max_samples = 20_000 + chunk_size = 2_000 # add that many new samples before comparing against the truth + atol, rtol = 1e-3, 1e-2 + + estimator = HutchinsonTraceEstimator(A) + + while len(samples) < max_samples: + samples.extend( + [estimator.sample(distribution=distribution) for _ in range(chunk_size)] + ) + tr_estimator = mean(samples) + if not isclose(tr_A, tr_estimator, atol=atol, rtol=rtol): + print(f"{len(samples)} samples: Tr(A)={tr_A:.5f}≠{tr_estimator:.5f}.") + else: + # quit once the estimator has converged + break + + tr_estimator = mean(samples) + assert isclose(tr_A, tr_estimator, atol=atol, rtol=rtol)