diff --git a/curvlinops/norm/hutchinson.py b/curvlinops/norm/hutchinson.py index 0d5000c..4d0818b 100644 --- a/curvlinops/norm/hutchinson.py +++ b/curvlinops/norm/hutchinson.py @@ -1,8 +1,9 @@ """Hutchinson-style matrix norm estimation.""" +from numpy import dot from scipy.sparse.linalg import LinearOperator -from curvlinops.trace.hutchinson import HutchinsonTraceEstimator +from curvlinops.sampling import random_vector class HutchinsonSquaredFrobeniusNormEstimator: @@ -43,7 +44,7 @@ def __init__(self, A: LinearOperator): Args: A: Linear operator whose squared Frobenius norm will be estimated. """ - self._trace_estimator = HutchinsonTraceEstimator(A.T @ A) + self._A = A def sample(self, distribution: str = "rademacher") -> float: """Draw a sample from the squared Frobenius norm estimator. @@ -59,4 +60,7 @@ def sample(self, distribution: str = "rademacher") -> float: Returns: Sample from the squared Frobenius norm estimator. """ - return self._trace_estimator.sample(distribution=distribution) + dim = self._A.shape[1] + v = random_vector(dim, distribution) + Av = self._A @ v + return dot(Av, Av)