Skip to content

Commit

Permalink
[DOC] Add Neumann inverse to natural gradient example (#29)
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel authored Apr 6, 2023
1 parent 891195a commit 00733c2
Showing 1 changed file with 68 additions and 2 deletions.
70 changes: 68 additions & 2 deletions docs/examples/basic_usage/example_inverses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
This example demonstrates how to work with inverses of linear operators.
:code:`curvlinops` offers multiple ways to compute the inverse of a linear operator:
conjugate gradient (CG) and Neumann inversion. We will demonstrate CG inversion
first and conclude with a comparison to Neumann inversion.
Concretely, we will compute the natural gradient :math:`\mathbf{\tilde{g}} =
\mathbf{F}^{-1} \mathbf{g}`, defined by the inverse Fisher information
matrix :math:`\mathbf{F}^{-1}` and the gradient :math:`\mathbf{g}`. We can use
Expand All @@ -21,10 +25,14 @@
import numpy
import torch
from scipy import sparse
from scipy.sparse.linalg import aslinearoperator
from scipy.sparse.linalg import aslinearoperator, eigsh
from torch import nn

from curvlinops import CGInverseLinearOperator, GGNLinearOperator
from curvlinops import (
CGInverseLinearOperator,
GGNLinearOperator,
NeumannInverseLinearOperator,
)
from curvlinops.examples.functorch import functorch_ggn, functorch_gradient
from curvlinops.examples.utils import report_nonclose

Expand Down Expand Up @@ -218,3 +226,61 @@
ax[1].set_title("Inv. damped GGN/Fisher")
image = ax[1].imshow(numpy.log10(numpy.abs(inv_damped_GGN_mat)))
plt.colorbar(image, ax=ax[1], shrink=0.5)

# %%
#
# Neumann inverse (CG alternative)
# --------------------------------
#
# So far, we used CG to solve the linear system :math:`\mathbf{F}
# \mathbf{\tilde{g}} = \mathbf{g}` for the natural gradient
# :math:`\mathbf{\tilde{g}}` (i.e. the result of the inverse Fisher-gradient
# product). Alternatively, we can use the truncated `Neumann series
# <https://en.wikipedia.org/wiki/Neumann_series>`_ to approximate the inverse,
# using :py:class:`NeumannLinearOperator`.
#
# .. note::
# The Neumann series does not always converge. But we can use a re-scaling
# trick to make it converge if we know the matrix is PSD and are given its
# largest eigenvalue. More information can be found in the docstring.
#
# To make the Neumann series converge, we need to know the largest eigenvalue
# of the matrix to be inverted:
max_eigval = eigsh(damped_GGN, k=1, which="LM", return_eigenvectors=False)[0]
# eigenvalues (scale * damped_GGN_mat) are in [0; 2)
scale = 1.0 if max_eigval < 2.0 else 1.99 / max_eigval

# %%
#
# Let's compute the inverse approximation for different truncation numbers:

num_terms = [10]
neumann_inverses = []

for n in num_terms:
inv = NeumannInverseLinearOperator(damped_GGN, scale=scale, num_terms=n)
neumann_inverses.append(inv @ numpy.eye(inv.shape[1]))

# %%
#
# Here are their visualizations:

fig, axes = plt.subplots(ncols=len(num_terms) + 1)
plt.suptitle("Inverse damped Fisher (logarithm of absolute values)")

for i, (n, inv) in enumerate(zip(num_terms, neumann_inverses)):
ax = axes.flat[i]
ax.set_title(f"Neumann, {n} terms")
image = ax.imshow(numpy.log10(numpy.abs(inv)))
plt.colorbar(image, ax=ax, shrink=0.5)

ax = axes.flat[-1]
ax.set_title("Exact inverse")
image = ax.imshow(numpy.log10(numpy.abs(inv_damped_GGN_mat)))
plt.colorbar(image, ax=ax, shrink=0.5)

# %%
#
# The Neumann inversion is usually more inaccurate than CG inversion. But it
# might sometimes be preferred if only a rough approximation of the inverse
# matrix product is needed.

0 comments on commit 00733c2

Please sign in to comment.