-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
163 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Trace estimation techniques.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
"""Vanilla Hutchinson trace estimation.""" | ||
|
||
from typing import Callable, Dict | ||
|
||
from numpy import dot, ndarray | ||
from scipy.sparse.linalg import LinearOperator | ||
|
||
from curvlinops.trace.sampling import normal, rademacher | ||
|
||
|
||
class HutchinsonTraceEstimator: | ||
"""Class to perform trace estimation with Hutchinson's method. | ||
For details, see | ||
- Hutchinson, M. (1989). A stochastic estimator of the trace of the influence | ||
matrix for laplacian smoothing splines. Communication in Statistics---Simulation | ||
and Computation. | ||
Attributes: | ||
SUPPORTED_SAMPLINGS: Dictionary mapping supported distributions to their | ||
sampling functions. | ||
""" | ||
|
||
SUPPORTED_SAMPLINGS: Dict[str, Callable[[int], ndarray]] = { | ||
"rademacher": rademacher, | ||
"normal": normal, | ||
} | ||
|
||
def __init__(self, A: LinearOperator): | ||
"""Store the linear operator whose trace will be estimated. | ||
Args: | ||
A: Linear square-shaped operator whose trace will be estimated. | ||
Raises: | ||
ValueError: If the operator is not square. | ||
""" | ||
if len(A.shape) != 2 or A.shape[0] != A.shape[1]: | ||
raise ValueError(f"A must be square. Got shape {A.shape}.") | ||
self._A = A | ||
|
||
def sample(self, distribution: str = "rademacher") -> float: | ||
"""Draw a sample from the trace estimator. | ||
Multiple samples can be combined into a more accurate trace estimation via | ||
averaging. | ||
Args: | ||
distribution: Distribution of the vector along which the linear operator | ||
will be evaluated. Either `'rademacher'` or `'normal'`. | ||
Default is `'rademacher'`. | ||
Returns: | ||
Sample from the trace estimator. | ||
Raises: | ||
ValueError: If the distribution is not supported. | ||
Example: | ||
>>> from numpy import trace, mean | ||
>>> from numpy.random import rand, seed | ||
>>> seed(0) # make deterministic | ||
>>> A = rand(10, 10) | ||
>>> tr_A = trace(A) # exact trace as reference | ||
>>> estimator = HutchinsonTraceEstimator(A) | ||
>>> # one- and multi-sample approximations | ||
>>> tr_A_low_precision = estimator.sample() | ||
>>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)]) | ||
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision) | ||
>>> tr_A, tr_A_low_precision, tr_A_high_precision | ||
(4.457529730942303, 6.679568384120655, 4.388630875995861) | ||
""" | ||
dim = self._A.shape[1] | ||
|
||
if distribution not in self.SUPPORTED_SAMPLINGS: | ||
raise ValueError( | ||
f"Unsupported distribution '{distribution}'. " | ||
f"Supported distributions are {list(self.SUPPORTED_SAMPLINGS)}." | ||
) | ||
|
||
v = self.SUPPORTED_SAMPLINGS[distribution](dim) | ||
Av = self._A @ v | ||
|
||
return dot(v, Av) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Sampling methods for random vectors.""" | ||
|
||
from numpy import ndarray | ||
from numpy.random import binomial, randn | ||
|
||
|
||
def rademacher(dim: int) -> ndarray: | ||
"""Draw a vector with i.i.d. Rademacher elements. | ||
Args: | ||
dim: Dimension of the vector. | ||
Returns: | ||
Vector with i.i.d. Rademacher elements and specified dimension. | ||
""" | ||
num_trials, success_prob = 1, 0.5 | ||
return binomial(num_trials, success_prob, size=dim).astype(float) * 2 - 1 | ||
|
||
|
||
def normal(dim: int) -> ndarray: | ||
"""Drawa vector with i.i.d. standard normal elements. | ||
Args: | ||
dim: Dimension of the vector. | ||
Returns: | ||
Vector with i.i.d. standard normal elements and specified dimension. | ||
""" | ||
return randn(dim) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Test ``curvlinops.trace``.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
"""Test ``curvlinops.trace.__init__``.""" | ||
|
||
from numpy import isclose, mean, trace | ||
from numpy.random import rand, seed | ||
from pytest import mark | ||
|
||
from curvlinops import HutchinsonTraceEstimator | ||
|
||
DISTRIBUTIONS = ["rademacher", "normal"] | ||
DISTRIBUTION_IDS = [f"distribution={distribution}" for distribution in DISTRIBUTIONS] | ||
|
||
|
||
@mark.parametrize("distribution", DISTRIBUTIONS, ids=DISTRIBUTION_IDS) | ||
def test_HutchinsonTraceEstimator(distribution: str): | ||
"""Test whether Hutchinon's trace estimator converges to the true trace. | ||
Args: | ||
distribution: Distribution of the random vectors used for the trace estimation. | ||
""" | ||
seed(0) | ||
A = rand(10, 10) | ||
tr_A = trace(A) | ||
|
||
samples = [] | ||
max_samples = 20_000 | ||
chunk_size = 2_000 # add that many new samples before comparing against the truth | ||
atol, rtol = 1e-3, 1e-2 | ||
|
||
estimator = HutchinsonTraceEstimator(A) | ||
|
||
while len(samples) < max_samples: | ||
samples.extend( | ||
[estimator.sample(distribution=distribution) for _ in range(chunk_size)] | ||
) | ||
tr_estimator = mean(samples) | ||
if not isclose(tr_A, tr_estimator, atol=atol, rtol=rtol): | ||
print(f"{len(samples)} samples: Tr(A)={tr_A:.5f}≠{tr_estimator:.5f}.") | ||
else: | ||
# quit once the estimator has converged | ||
break | ||
|
||
tr_estimator = mean(samples) | ||
assert isclose(tr_A, tr_estimator, atol=atol, rtol=rtol) |