Skip to content

Commit

Permalink
Add inverse EKFAC support
Browse files: browse the repository at this point in the history
  • Loading branch information
runame committed Sep 17, 2024
1 parent 18062ef commit b096726
Showing 1 changed file with 72 additions and 123 deletions.
195 changes: 72 additions & 123 deletions curvlinops/inverse.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -10,11 +10,7 @@
from torch import Tensor, cat, cholesky_inverse, eye, float64, outer
from torch.linalg import cholesky, eigh

from curvlinops.kfac import KFACLinearOperator, ParameterMatrixType

KFACInvType = TypeVar(
"KFACInvType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]
)
from curvlinops.kfac import KFACLinearOperator, KFACType, ParameterMatrixType


class _InverseLinearOperator(LinearOperator):
Expand Down Expand Up @@ -355,15 +351,17 @@ def __init__(
raise ValueError(
"Heuristic and exact damping require a single damping value."
)
if self._A._correct_eigenvalues and not use_exact_damping:
raise ValueError("Only exact damping is supported for EKFAC.")

self._damping = damping
self._use_heuristic_damping = use_heuristic_damping
self._min_damping = min_damping
self._use_exact_damping = use_exact_damping
self._cache = cache
self._retry_double_precision = retry_double_precision
self._inverse_input_covariances: Dict[str, KFACInvType] = {}
self._inverse_gradient_covariances: Dict[str, KFACInvType] = {}
self._inverse_input_covariances: Dict[str, KFACType] = {}
self._inverse_gradient_covariances: Dict[str, KFACType] = {}

def _compute_damping(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
Expand Down Expand Up @@ -408,18 +406,20 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor:
)

def _compute_inverse_factors(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
) -> Tuple[KFACInvType, KFACInvType]:
self, aaT: Optional[Tensor], ggT: Optional[Tensor], name: str
) -> Tuple[KFACType, KFACType, Optional[Tensor]]:
"""Compute the inverses of the Kronecker factors for a given layer.
Args:
aaT: Input covariance matrix. ``None`` for biases.
ggT: Gradient covariance matrix.
name: Name of the layer for which to invert Kronecker factors.
Returns:
Tuple of inverses (or eigendecompositions) of the input and gradient
covariance Kronecker factors. Can be ``None`` if the input or gradient
covariance is ``None`` (e.g. the input covariances for biases).
covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if
the input or gradient covariance is ``None`` (e.g. the input covariances for
biases).
Raises:
RuntimeError: If a Cholesky decomposition (and optionally the retry in
Expand All @@ -430,7 +430,27 @@ def _compute_inverse_factors(
# Kronecker-factored eigenbasis (KFE).
aaT_eigvals, aaT_eigvecs = (None, None) if aaT is None else eigh(aaT)
ggT_eigvals, ggT_eigvecs = (None, None) if ggT is None else eigh(ggT)
return (aaT_eigvecs, aaT_eigvals), (ggT_eigvecs, ggT_eigvals)
param_pos = self._A._mapping[name]
if (
not self._A._separate_weight_and_bias
and "weight" in param_pos
and "bias" in param_pos
):
inv_damped_eigenvalues = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues = {}
for p_name, pos in param_pos.items():
if p_name == "weight":
inv_damped_eigenvalues[pos] = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues[pos] = ggT_eigvals.add(
self._damping
).pow_(-1)
return aaT_eigvecs, ggT_eigvecs, inv_damped_eigenvalues
else:
damping_aaT, damping_ggT = self._compute_damping(aaT, ggT)

Expand Down Expand Up @@ -476,129 +496,49 @@ def _compute_inverse_factors(
raise error
ggT_inv = None if ggT_chol is None else cholesky_inverse(ggT_chol)

return aaT_inv, ggT_inv
return aaT_inv, ggT_inv, None

def _compute_or_get_cached_inverse(
self, name: str
) -> Tuple[KFACInvType, KFACInvType]:
) -> Tuple[KFACType, KFACType, Optional[Tensor]]:
"""Invert the Kronecker factors of the KFACLinearOperator or retrieve them.
Args:
name: Name of the layer for which to invert Kronecker factors.
Returns:
Tuple of inverses (or eigendecompositions) of the input and gradient
covariance Kronecker factors. Can be ``None`` if the input or gradient
covariance is ``None`` (e.g. the input covariances for biases).
covariance Kronecker factors and optionally eigenvalues. Can be ``None`` if
the input or gradient covariance is ``None`` (e.g. the input covariances for
biases).
"""
if name in self._inverse_input_covariances:
aaT_inv = self._inverse_input_covariances.get(name)
ggT_inv = self._inverse_gradient_covariances.get(name)
return aaT_inv, ggT_inv

aaT = self._A._input_covariances.get(name)
ggT = self._A._gradient_covariances.get(name)
aaT_inv, ggT_inv = self._compute_inverse_factors(aaT, ggT)

if self._cache:
self._inverse_input_covariances[name] = aaT_inv
self._inverse_gradient_covariances[name] = ggT_inv

return aaT_inv, ggT_inv

def _left_and_right_multiply(
self, M_joint: Tensor, aaT_inv: KFACInvType, ggT_inv: KFACInvType
) -> Tensor:
"""Left and right multiply matrix with inverse Kronecker factors.
Args:
M_joint: Matrix for multiplication.
aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for
biases.
ggT_inv: Inverse of the gradient covariance Kronecker factor.
Returns:
Matrix-multiplication result ``KFAC⁻¹ @ M_joint``.
"""
if self._use_exact_damping:
# Perform damped preconditioning in KFE, e.g. see equation (21) in
# https://arxiv.org/abs/2308.03296.
aaT_eigvecs, aaT_eigvals = aaT_inv
ggT_eigvecs, ggT_eigvals = ggT_inv
# Transform in eigenbasis.
M_joint = einsum(
ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l"
)
# Divide by damped eigenvalues to perform the inversion.
M_joint.div_(outer(ggT_eigvals, aaT_eigvals).add_(self._damping))
# Transform back to standard basis.
M_joint = einsum(
ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m j k, l k -> m i l"
)
return aaT_inv, ggT_inv, None

if self._A._correct_eigenvalues:
aaT_inv = self._A._input_covariances_eigenvectors.get(name)
ggT_inv = self._A._gradient_covariances_eigenvectors.get(name)
eigenvalues = self._A._corrected_eigenvalues.get(name)
if isinstance(eigenvalues, dict):
inv_damped_eigenvalues = {}
for key, val in eigenvalues.items():
inv_damped_eigenvalues[key] = val.add(self._damping).pow_(-1)
elif isinstance(eigenvalues, Tensor):
inv_damped_eigenvalues = eigenvalues.add(self._damping).pow_(-1)
else:
M_joint = einsum(ggT_inv, M_joint, aaT_inv, "i j, m j k, k l -> m i l")
return M_joint

def _separate_left_and_right_multiply(
self,
M_torch: Tensor,
param_pos: Dict[str, int],
aaT_inv: KFACInvType,
ggT_inv: KFACInvType,
) -> Tensor:
"""Multiply matrix with inverse Kronecker factors for separated weight and bias.
Args:
M_torch: Matrix for multiplication.
param_pos: Dictionary with positions of the weight and bias parameters.
aaT_inv: Inverse of the input covariance Kronecker factor. ``None`` for
biases.
ggT_inv: Inverse of the gradient covariance Kronecker factor.
Returns:
Matrix-multiplication result ``KFAC⁻¹ @ M_torch``.
"""
if self._use_exact_damping:
# Perform damped preconditioning in KFE, e.g. see equation (21) in
# https://arxiv.org/abs/2308.03296.
aaT_eigvecs, aaT_eigvals = aaT_inv
ggT_eigvecs, ggT_eigvals = ggT_inv

for p_name, pos in param_pos.items():
# for weights we need to multiply from the right with aaT
# for weights and biases we need to multiply from the left with ggT
if p_name == "weight":
M_w = rearrange(M_torch[pos], "m c_out ... -> m c_out (...)")
aaT_fac = aaT_eigvecs if self._use_exact_damping else aaT_inv
# If `use_exact_damping` is `True`, we transform to eigenbasis
M_torch[pos] = einsum(M_w, aaT_fac, "m i j, j k -> m i k")

ggT_fac = ggT_eigvecs if self._use_exact_damping else ggT_inv
dims = (
"m i ... -> m j ..."
if self._use_exact_damping
else " m j ... -> m i ..."
aaT = self._A._input_covariances.get(name)
ggT = self._A._gradient_covariances.get(name)
aaT_inv, ggT_inv, inv_damped_eigenvalues = self._compute_inverse_factors(
aaT, ggT, name
)
# If `use_exact_damping` is `True`, we transform to eigenbasis
M_torch[pos] = einsum(ggT_fac, M_torch[pos], f"i j, {dims}")

if self._use_exact_damping:
# Divide by damped eigenvalues to perform the inversion and transform
# back to standard basis.
if p_name == "weight":
M_torch[pos].div_(
outer(ggT_eigvals, aaT_eigvals).add_(self._damping)
)
M_torch[pos] = einsum(
M_torch[pos], aaT_eigvecs, "m i j, k j -> m i k"
)
else:
M_torch[pos].div_(ggT_eigvals.add_(self._damping))
M_torch[pos] = einsum(
ggT_eigvecs, M_torch[pos], "i j, m j ... -> m i ..."
)

return M_torch
if self._cache:
self._inverse_input_covariances[name] = aaT_inv
self._inverse_gradient_covariances[name] = ggT_inv

return aaT_inv, ggT_inv, inv_damped_eigenvalues

def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
"""Apply the inverse of KFAC to a matrix (multiple vectors) in PyTorch.
Expand All @@ -621,12 +561,19 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
``[D, K]`` with some ``K``.
"""
return_tensor, M_torch = self._A._check_input_type_and_preprocess(M_torch)
if not self._A._input_covariances and not self._A._gradient_covariances:
if (
not self._A._input_covariances
and not self._A._gradient_covariances
and not self._A._input_covariances_eigenvectors
and not self._A._gradient_covariances_eigenvectors
):
self._A._compute_kfac()

for mod_name, param_pos in self._A._mapping.items():
# retrieve the inverses of the Kronecker factors from cache or invert them
aaT_inv, ggT_inv = self._compute_or_get_cached_inverse(mod_name)
aaT_inv, ggT_inv, inv_damped_eigenvalues = (
self._compute_or_get_cached_inverse(mod_name)
)
# cache the weight shape to ensure correct shapes are returned
if "weight" in param_pos:
weight_shape = M_torch[param_pos["weight"]].shape
Expand All @@ -640,12 +587,14 @@ def torch_matmat(self, M_torch: ParameterMatrixType) -> ParameterMatrixType:
w_pos, b_pos = param_pos["weight"], param_pos["bias"]
M_w = rearrange(M_torch[w_pos], "m c_out ... -> m c_out (...)")
M_joint = cat([M_w, M_torch[b_pos].unsqueeze(2)], dim=2)
M_joint = self._left_and_right_multiply(M_joint, aaT_inv, ggT_inv)
M_joint = self._A._left_and_right_multiply(
M_joint, aaT_inv, ggT_inv, inv_damped_eigenvalues
)
w_cols = M_w.shape[2]
M_torch[w_pos], M_torch[b_pos] = M_joint.split([w_cols, 1], dim=2)
else:
M_torch = self._separate_left_and_right_multiply(
M_torch, param_pos, aaT_inv, ggT_inv
M_torch = self._A._separate_left_and_right_multiply(
M_torch, param_pos, aaT_inv, ggT_inv, inv_damped_eigenvalues
)

# restore original shapes
Expand Down

0 comments on commit b096726

Please sign in to comment.