diff --git a/curvlinops/trace/hutchinson.py b/curvlinops/trace/hutchinson.py index fb69e16..79f013c 100644 --- a/curvlinops/trace/hutchinson.py +++ b/curvlinops/trace/hutchinson.py @@ -18,7 +18,7 @@ class HutchinsonTraceEstimator: and Computation. Example: - >>> from numpy import trace, mean + >>> from numpy import trace, mean, round >>> from numpy.random import rand, seed >>> seed(0) # make deterministic >>> A = rand(10, 10) @@ -28,8 +28,8 @@ class HutchinsonTraceEstimator: >>> tr_A_low_precision = estimator.sample() >>> tr_A_high_precision = mean([estimator.sample() for _ in range(1_000)]) >>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision) - >>> tr_A, tr_A_low_precision, tr_A_high_precision - (4.457529730942303, 6.679568384120655, 4.388630875995861) + >>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4) + (4.4575, 6.6796, 4.3886) Attributes: SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their @@ -75,7 +75,7 @@ def sample(self, distribution: str = "rademacher") -> float: if distribution not in self.SUPPORTED_DISTRIBUTIONS: raise ValueError( - f"Unsupported distribution '{distribution}'. " + f"Unsupported distribution {distribution:!r}. " f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}." )