From 42ea5534a2851cd958ad89d5de4ed683c6138c20 Mon Sep 17 00:00:00 2001 From: Felix Dangel Date: Mon, 16 Oct 2023 14:53:05 -0400 Subject: [PATCH] [REF] Move example, polish rst --- curvlinops/trace/hutchinson.py | 42 +++++++++---------- ...test__hutchinson.py => test_hutchinson.py} | 0 2 files changed, 21 insertions(+), 21 deletions(-) rename test/trace/{test__hutchinson.py => test_hutchinson.py} (100%) diff --git a/curvlinops/trace/hutchinson.py b/curvlinops/trace/hutchinson.py index 8b5a8f5..fb69e16 100644 --- a/curvlinops/trace/hutchinson.py +++ b/curvlinops/trace/hutchinson.py @@ -17,12 +17,26 @@ class HutchinsonTraceEstimator: matrix for laplacian smoothing splines. Communication in Statistics---Simulation and Computation. + Example: + >>> from numpy import trace, mean + >>> from numpy.random import rand, seed + >>> seed(0) # make deterministic + >>> A = rand(10, 10) + >>> tr_A = trace(A) # exact trace as reference + >>> estimator = HutchinsonTraceEstimator(A) + >>> # one- and multi-sample approximations + >>> tr_A_low_precision = estimator.sample() + >>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)]) + >>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision) + >>> tr_A, tr_A_low_precision, tr_A_high_precision + (4.457529730942303, 6.679568384120655, 4.388630875995861) + Attributes: - SUPPORTED_SAMPLINGS: Dictionary mapping supported distributions to their + SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their sampling functions. """ - SUPPORTED_SAMPLINGS: Dict[str, Callable[[int], ndarray]] = { + SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = { "rademacher": rademacher, "normal": normal, } @@ -48,38 +62,24 @@ def sample(self, distribution: str = "rademacher") -> float: Args: distribution: Distribution of the vector along which the linear operator - will be evaluated. Either `'rademacher'` or `'normal'`. - Default is `'rademacher'`. + will be evaluated. Either ``'rademacher'`` or ``'normal'``. + Default is ``'rademacher'``. Returns: Sample from the trace estimator. Raises: ValueError: If the distribution is not supported. - - Example: - >>> from numpy import trace, mean - >>> from numpy.random import rand, seed - >>> seed(0) # make deterministic - >>> A = rand(10, 10) - >>> tr_A = trace(A) # exact trace as reference - >>> estimator = HutchinsonTraceEstimator(A) - >>> # one- and multi-sample approximations - >>> tr_A_low_precision = estimator.sample() - >>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)]) - >>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision) - >>> tr_A, tr_A_low_precision, tr_A_high_precision - (4.457529730942303, 6.679568384120655, 4.388630875995861) """ dim = self._A.shape[1] - if distribution not in self.SUPPORTED_SAMPLINGS: + if distribution not in self.SUPPORTED_DISTRIBUTIONS: raise ValueError( f"Unsupported distribution '{distribution}'. " - f"Supported distributions are {list(self.SUPPORTED_SAMPLINGS)}." + f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}." ) - v = self.SUPPORTED_SAMPLINGS[distribution](dim) + v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim) Av = self._A @ v return dot(v, Av) diff --git a/test/trace/test__hutchinson.py b/test/trace/test_hutchinson.py similarity index 100% rename from test/trace/test__hutchinson.py rename to test/trace/test_hutchinson.py