Skip to content

Commit

Permalink
Add damping argument to KFACInverseLinearOperator
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Feb 6, 2024
1 parent 49014e9 commit 7a4a36e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 8 deletions.
31 changes: 26 additions & 5 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Implements linear operator inverses."""

from typing import Dict, Tuple
from typing import Dict, Optional, Tuple

from einops import rearrange
from numpy import allclose, column_stack, ndarray
from scipy.sparse.linalg import LinearOperator, cg
from torch import Tensor, cat, cholesky_inverse
from torch import Tensor, cat, cholesky_inverse, eye
from torch.linalg import cholesky

from curvlinops.kfac import KFACLinearOperator
Expand Down Expand Up @@ -209,11 +209,17 @@ def _matvec(self, x: ndarray) -> ndarray:
class KFACInverseLinearOperator(_InverseLinearOperator):
"""Class to invert instances of the ``KFACLinearOperator``."""

def __init__(self, A: KFACLinearOperator, cache: bool = True):
def __init__(
self,
A: KFACLinearOperator,
damping: Optional[Tuple[float, float]] = None,
cache: bool = True,
):
"""Store the linear operator whose inverse should be represented.
Args:
A: ``KFACLinearOperator`` whose inverse is formed.
damping: Damping values for all input and gradient covariances.
cache: Whether to cache the inverses of the Kronecker factors.
Default: ``True``.
Expand All @@ -226,6 +232,7 @@ def __init__(self, A: KFACLinearOperator, cache: bool = True):
)
super().__init__(A.dtype, A.shape)
self._A = A
self._damping = damping
self._cache = cache
self._inverse_input_covariances: Dict[str, Tensor] = {}
self._inverse_gradient_covariances: Dict[str, Tensor] = {}
Expand All @@ -245,8 +252,22 @@ def _compute_or_get_cached_inverse(self, name: str) -> Tuple[Tensor, Tensor]:
else:
aaT = self._A._input_covariances.get(name)
ggT = self._A._gradient_covariances.get(name)
aaT_inv = cholesky_inverse(cholesky(aaT)) if aaT is not None else None
ggT_inv = cholesky_inverse(cholesky(ggT)) if ggT is not None else None
damping_aaT = self._damping[0] if self._damping is not None else 0.0
aaT_inv = (
cholesky_inverse(
cholesky(aaT + damping_aaT * eye(aaT.shape[0], device=aaT.device))
)
if aaT is not None
else None
)
damping_ggT = self._damping[1] if self._damping is not None else 0.0
ggT_inv = (
cholesky_inverse(
cholesky(ggT + damping_ggT * eye(ggT.shape[0], device=ggT.device))
)
if ggT is not None
else None
)
if self._cache:
self._inverse_input_covariances[name] = aaT_inv
self._inverse_gradient_covariances[name] = ggT_inv
Expand Down
12 changes: 9 additions & 3 deletions test/test_inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,21 @@ def test_KFAC_inverse_damped_matvec(
separate_weight_and_bias=separate_weight_and_bias,
)
KFAC._compute_kfac()
# add damping

# add damping manually
for aaT in KFAC._input_covariances.values():
aaT.add_(torch.eye(aaT.shape[0], device=aaT.device), alpha=delta)
for ggT in KFAC._gradient_covariances.values():
ggT.add_(torch.eye(ggT.shape[0], device=ggT.device), alpha=delta)

inv_KFAC = KFACInverseLinearOperator(KFAC, cache=cache)
inv_KFAC_naive = torch.inverse(torch.as_tensor(KFAC @ eye(KFAC.shape[0])))

# remove damping and pass it on as an argument instead
for aaT in KFAC._input_covariances.values():
aaT.sub_(torch.eye(aaT.shape[0], device=aaT.device), alpha=delta)
for ggT in KFAC._gradient_covariances.values():
ggT.sub_(torch.eye(ggT.shape[0], device=ggT.device), alpha=delta)
inv_KFAC = KFACInverseLinearOperator(KFAC, damping=(delta, delta), cache=cache)

x = random.rand(KFAC.shape[1])
report_nonclose(inv_KFAC @ x, inv_KFAC_naive @ x, rtol=5e-2)

Expand Down

0 comments on commit 7a4a36e

Please sign in to comment.