Skip to content

Commit

Permalink
[ADD] Hutchinson-style matrix diagonal estimation (#40)
Browse files Browse the repository at this point in the history
* [REF] Use same code in trace tests, extract random vector generation

* [ADD] Hutchinson-style diagonal estimation

* [DOC] Add diagonal estimator to documentation

* [DOC] Short summary for each trace/diagonal estimation method
  • Loading branch information
f-dangel authored Oct 17, 2023
1 parent 42913be commit c0f66e1
Show file tree
Hide file tree
Showing 13 changed files with 303 additions and 134 deletions.
2 changes: 2 additions & 0 deletions curvlinops/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""``curvlinops`` library API."""

from curvlinops.diagonal.hutchinson import HutchinsonDiagonalEstimator
from curvlinops.fisher import FisherMCLinearOperator
from curvlinops.ggn import GGNLinearOperator
from curvlinops.gradient_moments import EFLinearOperator
Expand Down Expand Up @@ -32,4 +33,5 @@
"LanczosApproximateLogSpectrumCached",
"HutchinsonTraceEstimator",
"HutchPPTraceEstimator",
"HutchinsonDiagonalEstimator",
]
1 change: 1 addition & 0 deletions curvlinops/diagonal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Matrix diagonal estimation methods."""
87 changes: 87 additions & 0 deletions curvlinops/diagonal/hutchinson.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""Hutchinson-style matrix diagonal estimation."""


from numpy import ndarray
from scipy.sparse.linalg import LinearOperator

from curvlinops.sampling import random_vector


class HutchinsonDiagonalEstimator:
    r"""Estimate the diagonal of a square linear operator via Hutchinson's method.

    Reference:

    - Martens, J., Sutskever, I., & Swersky, K. (2012). Estimating the hessian by
      back-propagating curvature. International Conference on Machine Learning (ICML).

    Given a square linear operator :math:`\mathbf{A}`, its diagonal
    :math:`\mathrm{diag}(\mathbf{A})` is approximated by drawing a random vector
    :math:`\mathbf{v}` with :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}`
    and evaluating the estimator

    .. math::
        \mathbf{a}
        := \mathbf{v} \odot \mathbf{A} \mathbf{v}
        \approx \mathrm{diag}(\mathbf{A})\,.

    The estimator is unbiased, because

    .. math::
        \mathbb{E}[a_i]
        = \sum_j \mathbb{E}[v_i A_{i,j} v_j]
        = \sum_j A_{i,j} \mathbb{E}[v_i v_j]
        = \sum_j A_{i,j} \delta_{i, j}
        = A_{i,i}\,.

    Example:
        >>> from numpy import diag, mean, round
        >>> from numpy.random import rand, seed
        >>> from numpy.linalg import norm
        >>> seed(0) # make deterministic
        >>> A = rand(10, 10)
        >>> diag_A = diag(A) # exact diagonal as reference
        >>> estimator = HutchinsonDiagonalEstimator(A)
        >>> # one- and multi-sample approximations
        >>> diag_A_low_precision = estimator.sample()
        >>> samples = [estimator.sample() for _ in range(1_000)]
        >>> diag_A_high_precision = mean(samples, axis=0)
        >>> # compute residual norms
        >>> error_low_precision = norm(diag_A - diag_A_low_precision)
        >>> error_high_precision = norm(diag_A - diag_A_high_precision)
        >>> assert error_low_precision > error_high_precision
        >>> round(error_low_precision, 4), round(error_high_precision, 4)
        (5.7268, 0.1525)
    """

    def __init__(self, A: LinearOperator):
        """Remember the operator whose diagonal should be estimated.

        Args:
            A: Square linear operator (anything exposing a two-entry ``shape`` and
                supporting ``@`` with a vector).

        Raises:
            ValueError: If ``A`` is not two-dimensional with equal side lengths.
        """
        is_square = len(A.shape) == 2 and A.shape[0] == A.shape[1]
        if not is_square:
            raise ValueError(f"A must be square. Got shape {A.shape}.")
        self._A = A

    def sample(self, distribution: str = "rademacher") -> ndarray:
        """Return one draw of the diagonal estimator.

        Averaging several draws yields a more accurate estimate of the diagonal.

        Args:
            distribution: Name of the distribution the probing vector is drawn
                from. Either ``'rademacher'`` or ``'normal'``.
                Default is ``'rademacher'``.

        Returns:
            A sample from the diagonal estimator.
        """
        num_cols = self._A.shape[1]
        probe = random_vector(num_cols, distribution)
        # element-wise product of the probe with its image under A
        return probe * (self._A @ probe)
51 changes: 51 additions & 0 deletions curvlinops/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Sampling methods for random vectors."""

from numpy import ndarray
from numpy.random import binomial, randn


def rademacher(dim: int) -> ndarray:
    """Sample a vector whose entries are i.i.d. Rademacher variables.

    Args:
        dim: Number of entries of the vector.

    Returns:
        Float vector of length ``dim`` with entries in ``{-1, +1}``.
    """
    # A single fair Bernoulli trial per entry yields values in {0, 1} ...
    coin_flips = binomial(1, 0.5, size=dim)
    # ... which are then mapped affinely onto {-1, +1}.
    return coin_flips.astype(float) * 2 - 1


def normal(dim: int) -> ndarray:
    """Sample a vector whose entries are i.i.d. standard normal variables.

    Args:
        dim: Number of entries of the vector.

    Returns:
        Vector of length ``dim`` drawn via ``numpy.random.randn``.
    """
    gaussian_draws = randn(dim)
    return gaussian_draws


def random_vector(dim: int, distribution: str) -> ndarray:
    """Draw a vector with i.i.d. elements.

    Args:
        dim: Dimension of the vector.
        distribution: Distribution of the vector's elements. Either
            ``'rademacher'`` or ``'normal'``.

    Returns:
        Vector with i.i.d. elements and specified dimension.

    Raises:
        ValueError: If the distribution is unknown.
    """
    if distribution == "rademacher":
        return rademacher(dim)
    if distribution == "normal":
        return normal(dim)
    # NOTE: ``!r`` is a conversion and must not be preceded by ``:``. Written as
    # ``{distribution:!r}`` it is parsed as a format spec and formatting itself
    # fails with "Invalid format specifier" instead of producing this message.
    raise ValueError(f"Unknown distribution {distribution!r}.")
49 changes: 23 additions & 26 deletions curvlinops/trace/hutchinson.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,39 @@
"""Vanilla Hutchinson trace estimation."""

from typing import Callable, Dict

from numpy import dot, ndarray
from numpy import dot
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher
from curvlinops.sampling import random_vector


class HutchinsonTraceEstimator:
"""Class to perform trace estimation with Hutchinson's method.
r"""Class to perform trace estimation with Hutchinson's method.
For details, see
- Hutchinson, M. (1989). A stochastic estimator of the trace of the influence
matrix for laplacian smoothing splines. Communication in Statistics---Simulation
and Computation.
Let :math:`\mathbf{A}` be a square linear operator. We can approximate its trace
:math:`\mathrm{Tr}(\mathbf{A})` by drawing a random vector :math:`\mathbf{v}`
which satisfies :math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and
sample from the estimator
.. math::
a
:= \mathbf{v}^\top \mathbf{A} \mathbf{v}
\approx \mathrm{Tr}(\mathbf{A})\,.
This estimator is unbiased,
.. math::
\mathbb{E}[a]
= \mathrm{Tr}(\mathbb{E}[\mathbf{v}^\top\mathbf{A} \mathbf{v}])
= \mathrm{Tr}(\mathbf{A} \mathbb{E}[\mathbf{v} \mathbf{v}^\top])
= \mathrm{Tr}(\mathbf{A} \mathbf{I})
= \mathrm{Tr}(\mathbf{A})\,.
Example:
>>> from numpy import trace, mean, round
>>> from numpy.random import rand, seed
Expand All @@ -30,17 +47,8 @@ class HutchinsonTraceEstimator:
>>> assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(4.4575, 6.6796, 4.3886)
Attributes:
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}

def __init__(self, A: LinearOperator):
"""Store the linear operator whose trace will be estimated.
Expand All @@ -67,19 +75,8 @@ def sample(self, distribution: str = "rademacher") -> float:
Returns:
Sample from the trace estimator.
Raises:
ValueError: If the distribution is not supported.
"""
dim = self._A.shape[1]

if distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)

v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
v = random_vector(dim, distribution)
Av = self._A @ v

return dot(v, Av)
60 changes: 27 additions & 33 deletions curvlinops/trace/meyer2020hutch.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Implementation of Hutch++ trace estimation from Meyer et al."""

from typing import Callable, Dict, Optional, Union
from typing import Optional, Union

from numpy import column_stack, dot, ndarray
from numpy.linalg import qr
from scipy.sparse.linalg import LinearOperator

from curvlinops.trace.sampling import normal, rademacher
from curvlinops.sampling import random_vector


class HutchPPTraceEstimator:
"""Class to perform trace estimation with the Huch++ method.
r"""Class to perform trace estimation with the Huch++ method.
In contrast to vanilla Hutchinson, Hutch++ has lower variance, but requires more
memory.
Expand All @@ -20,6 +20,26 @@ class HutchPPTraceEstimator:
- Meyer, R. A., Musco, C., Musco, C., & Woodruff, D. P. (2020). Hutch++:
optimal stochastic trace estimation.
Let :math:`\mathbf{A}` be a square linear operator whose trace we want to
approximate. First, we compute an orthonormal basis :math:`\mathbf{Q}` of a
sub-space spanned by :math:`\mathbf{A} \mathbf{S}` where :math:`\mathbf{S}` is a
tall random matrix with i.i.d. elements. Then, we compute the trace in the sub-space
and apply Hutchinson's estimator in the remaining space spanned by
:math:`\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top`: We can draw a random vector
:math:`\mathbf{v}` which satisfies
:math:`\mathbb{E}[\mathbf{v} \mathbf{v}^\top] = \mathbf{I}` and sample from the
estimator
.. math::
a
:= \mathrm{Tr}(\mathbf{Q}^\top \mathbf{A} \mathbf{Q})
+ \mathbf{v}^\top (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top)^\top
\mathbf{A} (\mathbf{I} - \mathbf{Q} \mathbf{Q}^\top) \mathbf{v}
\approx \mathrm{Tr}(\mathbf{A})\,.
This estimator is unbiased, :math:`\mathbb{E}[a] = \mathrm{Tr}(\mathbf{A})`, as the
first term is constant and the second part is Hutchinson's estimator in a sub-space.
Example:
>>> from numpy import trace, mean, round
>>> from numpy.random import rand, seed
Expand All @@ -33,17 +53,8 @@ class HutchPPTraceEstimator:
>>> # assert abs(tr_A - tr_A_low_precision) > abs(tr_A - tr_A_high_precision)
>>> round(tr_A, 4), round(tr_A_low_precision, 4), round(tr_A_high_precision, 4)
(4.4575, 2.4085, 4.5791)
Attributes:
SUPPORTED_DISTRIBUTIONS: Dictionary mapping supported distributions to their
sampling functions.
"""

SUPPORTED_DISTRIBUTIONS: Dict[str, Callable[[int], ndarray]] = {
"rademacher": rademacher,
"normal": normal,
}

def __init__(
self,
A: LinearOperator,
Expand All @@ -64,8 +75,8 @@ def __init__(
``'rademacher'``.
Raises:
ValueError: If the operator is not square, the basis dimension is too
large, or the sampling distribution is not supported.
ValueError: If the operator is not square or the basis dimension is too
large.
Note:
If you are planning to perform a fair (i.e. same computation budget)
Expand All @@ -86,12 +97,6 @@ def __init__(
f"Basis dimension must be at most {self._A.shape[1]}. Got {basis_dim}."
)
self._basis_dim = basis_dim

if basis_distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {basis_distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)
self._basis_distribution = basis_distribution

# When drawing the first sample, the basis and its subspace trace will be
Expand All @@ -116,25 +121,14 @@ def sample(self, distribution: str = "rademacher") -> float:
Returns:
Sample from the trace estimator.
Raises:
ValueError: If the distribution is not supported.
"""
self.maybe_compute_and_cache_subspace()

if distribution not in self.SUPPORTED_DISTRIBUTIONS:
raise ValueError(
f"Unsupported distribution {distribution:!r}. "
f"Supported distributions are {list(self.SUPPORTED_DISTRIBUTIONS)}."
)

dim = self._A.shape[1]
v = self.SUPPORTED_DISTRIBUTIONS[distribution](dim)
v = random_vector(dim, distribution)
# project out subspace
v -= self._Q @ (self._Q.T @ v)

Av = self._A @ v

return self._tr_QT_A_Q + dot(v, Av)

def maybe_compute_and_cache_subspace(self):
Expand All @@ -145,7 +139,7 @@ def maybe_compute_and_cache_subspace(self):
dim = self._A.shape[1]
AS = column_stack(
[
self._A @ self.SUPPORTED_DISTRIBUTIONS[self._basis_distribution](dim)
self._A @ random_vector(dim, self._basis_distribution)
for _ in range(self._basis_dim)
]
)
Expand Down
29 changes: 0 additions & 29 deletions curvlinops/trace/sampling.py

This file was deleted.

Loading

0 comments on commit c0f66e1

Please sign in to comment.