Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ADD] EKFAC #127

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
206 changes: 86 additions & 120 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
"""Implements linear operator inverses."""

from math import sqrt
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
from warnings import warn

from einops import einsum, rearrange
from einops import rearrange
from numpy import allclose, column_stack, ndarray
from scipy.sparse.linalg import LinearOperator, cg, lsmr
from torch import Tensor, cat, cholesky_inverse, eye, float64, outer
from torch.linalg import cholesky, eigh

from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType

KFACInvType = TypeVar(
"KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]
)
from curvlinops.kfac import FactorType, KFACLinearOperator, ParameterMatrixType


class _InverseLinearOperator(LinearOperator):
Expand Down Expand Up @@ -355,15 +351,17 @@ def __init__(
raise ValueError(
"Heuristic and exact damping require a single damping value."
)
if self._A._correct_eigenvalues and not use_exact_damping:
raise ValueError("Only exact damping is supported for EKFAC.")

self._damping = damping
self._use_heuristic_damping = use_heuristic_damping
self._min_damping = min_damping
self._use_exact_damping = use_exact_damping
self._cache = cache
self._retry_double_precision = retry_double_precision
self._inverse_input_covariances: Dict[str, KFACInvType] = {}
self._inverse_gradient_covariances: Dict[str, KFACInvType] = {}
self._inverse_input_covariances: Dict[str, FactorType] = {}
self._inverse_gradient_covariances: Dict[str, FactorType] = {}

def _compute_damping(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
Expand Down Expand Up @@ -407,19 +405,54 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor:
M.add(eye(M.shape[0], dtype=M.dtype, device=M.device), alpha=damping)
)

def _compute_inv_damped_eigenvalues(
self, aaT_eigvals: Tensor, ggT_eigvals: Tensor, name: str
) -> Union[Tensor, Dict[str, Tensor]]:
"""Compute the inverses of the damped eigenvalues for a given layer.

Args:
aaT_eigvals: Eigenvalues of the input covariance matrix.
ggT_eigvals: Eigenvalues of the gradient covariance matrix.
name: Name of the layer for which to damp and invert eigenvalues.

Returns:
Inverses of the damped eigenvalues.
"""
param_pos = self._A._mapping[name]
if (
not self._A._separate_weight_and_bias
and "weight" in param_pos
and "bias" in param_pos
):
inv_damped_eigenvalues = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues = {}
for p_name, pos in param_pos.items():
inv_damped_eigenvalues[pos] = (
outer(ggT_eigvals, aaT_eigvals)
if p_name == "weight"
else ggT_eigvals
)
inv_damped_eigenvalues[pos].add_(self._damping).pow_(-1)
return inv_damped_eigenvalues

def _compute_inverse_factors(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
) -> Tuple[KFACInvType, KFACInvType]:
self, aaT: Optional[Tensor], ggT: Optional[Tensor], name: str
) -> Tuple[FactorType, FactorType, Optional[Tensor]]:
"""Compute the inverses of the Kronecker factors for a given layer.

Args:
aaT: Input covariance matrix. ``None`` for biases.
ggT: Gradient covariance matrix.
name: Name of the layer for which to invert Kronecker factors.

Returns:
Tuple of inverses (or eigendecompositions) of the input and gradient
covariance Kronecker factors. Can be ``None`` if the input or gradient
covariance is ``None`` (e.g. the input covariances for biases).
covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if
the input or gradient covariance is ``None`` (e.g. the input covariances for
biases).

Raises:
RuntimeError: If a Cholesky decomposition (and optionally the retry in
Expand All @@ -430,7 +463,10 @@ def _compute_inverse_factors(
# Kronecker-factored eigenbasis (KFE).
aaT_eigvals, aaT_eigvecs = (None, None) if aaT is None else eigh(aaT)
ggT_eigvals, ggT_eigvecs = (None, None) if ggT is None else eigh(ggT)
return (aaT_eigvecs, aaT_eigvals), (ggT_eigvecs, ggT_eigvals)
inv_damped_eigenvalues = self._compute_inv_damped_eigenvalues(
aaT_eigvals, ggT_eigvals, name
)
return aaT_eigvecs, ggT_eigvecs, inv_damped_eigenvalues
else:
damping_aaT, damping_ggT = self._compute_damping(aaT, ggT)

Expand Down Expand Up @@ -476,129 +512,50 @@ def _compute_inverse_factors(
raise error
ggT_inv = None if ggT_chol is None else cholesky_inverse(ggT_chol)

return aaT_inv, ggT_inv
return aaT_inv, ggT_inv, None

def _compute_or_get_cached_inverse(
self, name: str
) -> Tuple[KFACInvType, KFACInvType]:
) -> Tuple[FactorType, FactorType, Optional[Tensor]]:
"""Invert the Kronecker factors of the KFACLinearOperator or retrieve them.

Args:
name: Name of the layer for which to invert Kronecker factors.

Returns:
Tuple of inverses (or eigendecompositions) of the input and gradient
covariance Kronecker factors. Can be ``None`` if the input or gradient
covariance is ``None`` (e.g. the input covariances for biases).
covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if
the input or gradient covariance is ``None`` (e.g. the input covariances for
biases).
"""
if name in self._inverse_input_covariances:
aaT_inv = self._inverse_input_covariances.get(name)
ggT_inv = self._inverse_gradient_covariances.get(name)
return aaT_inv, ggT_inv
return aaT_inv, ggT_inv, None

if self._A._correct_eigenvalues:
aaT_eigenvecs = self._A._input_covariances_eigenvectors.get(name)
ggT_eigenvecs = self._A._gradient_covariances_eigenvectors.get(name)
eigenvalues = self._A._corrected_eigenvalues.get(name)
if isinstance(eigenvalues, dict):
inv_damped_eigenvalues = {}
for key, val in eigenvalues.items():
inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1)
elif isinstance(eigenvalues, Tensor):
inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1)
return aaT_eigenvecs, ggT_eigenvecs, inv_damped_eigenvalues

aaT = self._A._input_covariances.get(name)
ggT = self._A._gradient_covariances.get(name)
aaT_inv, ggT_inv = self._compute_inverse_factors(aaT, ggT)
aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors(
aaT, ggT, name
)

if self._cache:
self._inverse_input_covariances[name] = aaT_inv
self._inverse_gradient_covariances[name] = ggT_inv

return aaT_inv, ggT_inv

def _left_and_right_multiply(
    self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType
) -> Tensor:
    """Left and right multiply matrix with inverse Kronecker factors.

    Args:
        M_joint: Matrix for multiplication.
        aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for
            biases.
        ggT_inv: Inverse of the gradient covariance Kronecker factor.

    Returns:
        Matrix-multiplication result ``KFAC⁻¹ @ M_joint``.
    """
    if not self._use_exact_damping:
        # Plain Kronecker-factored preconditioning: ggT⁻¹ @ M @ aaT⁻¹.
        return einsum(ggT_inv, M_joint, aaT_inv, "i j, m j k, k l -> m i l")

    # Exact damping is applied in the Kronecker-factored eigenbasis (KFE),
    # e.g. see equation (21) in https://arxiv.org/abs/2308.03296.
    evecs_a, evals_a = aaT_inv
    evecs_g, evals_g = ggT_inv
    # Rotate into the eigenbasis.
    rotated = einsum(evecs_g, M_joint, evecs_a, "i j, m i k, k l -> m j l")
    # Divide by the damped eigenvalues to perform the inversion.
    rotated.div_(outer(evals_g, evals_a).add_(self._damping))
    # Rotate back into the standard basis.
    return einsum(evecs_g, rotated, evecs_a, "i j, m j k, l k -> m i l")

def _separate_left_and_right_multiply(
    self,
    M_torch: Tensor,
    param_pos: Dict[str, int],
    aaT_inv: KFACInvType,
    ggT_inv: KFACInvType,
) -> Tensor:
    """Multiply matrix with inverse Kronecker factors for separated weight and bias.

    Args:
        M_torch: Matrix for multiplication.
        param_pos: Dictionary with positions of the weight and bias parameters.
        aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for
            biases.
        ggT_inv: Inverse of the gradient covariance Kronecker factor.

    Returns:
        Matrix-multiplication result ``KFAC⁻¹ @ M_torch`` (modified in place
        entry-by-entry and also returned).
    """
    if self._use_exact_damping:
        # Perform damped preconditioning in KFE, e.g. see equation (21) in
        # https://arxiv.org/abs/2308.03296.
        # With exact damping the factors are (eigenvectors, eigenvalues)
        # tuples; for biases the input-covariance tuple is (None, None),
        # so unpacking is still safe.
        aaT_eigvecs, aaT_eigvals = aaT_inv
        ggT_eigvecs, ggT_eigvals = ggT_inv

    for p_name, pos in param_pos.items():
        # for weights we need to multiply from the right with aaT
        # for weights and biases we need to multiply from the left with ggT
        if p_name == "weight":
            # Flatten trailing (e.g. spatial) dims so the factor can be
            # applied as a plain matrix product from the right.
            M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)")
            aaT_fac = aaT_eigvecs if self._use_exact_damping else aaT_inv
            # If `use_exact_damping` is `True`, we transform to eigenbasis
            M_torch[pos] = einsum(M_w, aaT_fac, "m i j, j k -> m i k")

        ggT_fac = ggT_eigvecs if self._use_exact_damping else ggT_inv
        # The einsum pattern differs per mode: rotating into the eigenbasis
        # contracts with the eigenvectors' first index, whereas applying the
        # inverse factor contracts with its second index.
        dims = (
            "m i ... -> m j ..."
            if self._use_exact_damping
            else " m j ... -> m i ..."
        )
        # If `use_exact_damping` is `True`, we transform to eigenbasis
        M_torch[pos] = einsum(ggT_fac, M_torch[pos], f"i j, {dims}")

        if self._use_exact_damping:
            # Divide by damped eigenvalues to perform the inversion and transform
            # back to standard basis.
            if p_name == "weight":
                # Weight eigenvalues are the outer product of both factors'
                # eigenvalues; `outer` allocates, so in-place ops are safe.
                M_torch[pos].div_(
                    outer(ggT_eigvals, aaT_eigvals).add_(self._damping)
                )
                M_torch[pos] = einsum(
                    M_torch[pos], aaT_eigvecs, "m i j, k j -> m i k"
                )
            else:
                # NOTE(review): `add_` mutates `ggT_eigvals` in place, which
                # aliases the cached/caller-owned eigenvalue tensor — verify
                # this does not corrupt later uses (a clone would be safer).
                M_torch[pos].div_(ggT_eigvals.add_(self._damping))
                M_torch[pos] = einsum(
                    ggT_eigvecs, M_torch[pos], "i j, m j ... -> m i ..."
                )

    return M_torch
return aaT_inv, ggT_inv, inv_damped_eigenvalues

def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch.
Expand All @@ -621,12 +578,19 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
``[D, K]`` with some ``K``.
"""
return_tensor, M_torch = self._A._check_input_type_and_preprocess(M_torch)
if not self._A._input_covariances and not self._A._gradient_covariances:
if (
not self._A._input_covariances
and not self._A._gradient_covariances
and not self._A._input_covariances_eigenvectors
and not self._A._gradient_covariances_eigenvectors
):
self._A._compute_kfac()

for mod_name, param_pos in self._A._mapping.items():
# retrieve the inverses of the Kronecker factors from cache or invert them
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)
aaT_inv, ggT_inv, inv_damped_eigenvalues = (
self._compute_or_get_cached_inverse(mod_name)
)
# cache the weight shape to ensure correct shapes are returned
if "weight" in param_pos:
weight_shape = M_torch[param_pos["weight"]].shape
Expand All @@ -640,12 +604,14 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
w_pos, b_pos = param_pos["weight"], param_pos["bias"]
M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)")
M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2)
M_joint = self._left_and_right_multiply(M_joint, aaT_inv, ggT_inv)
M_joint = self._A._left_and_right_multiply(
M_joint, aaT_inv, ggT_inv, inv_damped_eigenvalues
)
w_cols = M_w.shape[2]
M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2)
else:
M_torch = self._separate_left_and_right_multiply(
M_torch, param_pos, aaT_inv, ggT_inv
M_torch = self._A._separate_left_and_right_multiply(
M_torch, param_pos, aaT_inv, ggT_inv, inv_damped_eigenvalues
)

# restore original shapes
Expand Down
Loading
Loading