Skip to content

Commit

Permalink
Fix flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Sep 17, 2024
1 parent b096726 commit a2dec74
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 35 deletions.
72 changes: 45 additions & 27 deletions curvlinops/inverse.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
"""Implements linear operator inverses."""

from math import sqrt
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
from warnings import warn

from einops import einsum, rearrange
from einops import rearrange
from numpy import allclose, column_stack, ndarray
from scipy.sparse.linalg import LinearOperator, cg, lsmr
from torch import Tensor, cat, cholesky_inverse, eye, float64, outer
from torch.linalg import cholesky, eigh

from curvlinops.kfac import KFACLinearOperator, KFACType, ParameterMatrixType
from curvlinops.kfac import FactorType, KFACLinearOperator, ParameterMatrixType


class _InverseLinearOperator(LinearOperator):
Expand Down Expand Up @@ -360,8 +360,8 @@ def __init__(
self._use_exact_damping = use_exact_damping
self._cache = cache
self._retry_double_precision = retry_double_precision
self._inverse_input_covariances: Dict[str, KFACType] = {}
self._inverse_gradient_covariances: Dict[str, KFACType] = {}
self._inverse_input_covariances: Dict[str, FactorType] = {}
self._inverse_gradient_covariances: Dict[str, FactorType] = {}

def _compute_damping(
self, aaT: Optional[Tensor], ggT: Optional[Tensor]
Expand Down Expand Up @@ -405,9 +405,44 @@ def _damped_cholesky(self, M: Tensor, damping: float) -> Tensor:
M.add(eye(M.shape[0], dtype=M.dtype, device=M.device), alpha=damping)
)

def _compute_inv_damped_eigenvalues(
self, aaT_eigvals: Tensor, ggT_eigvals: Tensor, name: str
) -> Union[Tensor, Dict[str, Tensor]]:
"""Compute the inverses of the damped eigenvalues for a given layer.
Args:
aaT_eigvals: Eigenvalues of the input covariance matrix.
ggT_eigvals: Eigenvalues of the gradient covariance matrix.
name: Name of the layer for which to damp and invert eigenvalues.
Returns:
Inverses of the damped eigenvalues.
"""
param_pos = self._A._mapping[name]
if (
not self._A._separate_weight_and_bias
and "weight" in param_pos
and "bias" in param_pos
):
inv_damped_eigenvalues = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues = {}
for p_name, pos in param_pos.items():
if p_name == "weight":
inv_damped_eigenvalues[pos] = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues[pos] = ggT_eigvals.add(self._damping).pow_(
-1
)
return inv_damped_eigenvalues

def _compute_inverse_factors(
self, aaT: Optional[Tensor], ggT: Optional[Tensor], name: str
) -> Tuple[KFACType, KFACType, Optional[Tensor]]:
) -> Tuple[FactorType, FactorType, Optional[Tensor]]:
"""Compute the inverses of the Kronecker factors for a given layer.
Args:
Expand All @@ -430,26 +465,9 @@ def _compute_inverse_factors(
# Kronecker-factored eigenbasis (KFE).
aaT_eigvals, aaT_eigvecs = (None, None) if aaT is None else eigh(aaT)
ggT_eigvals, ggT_eigvecs = (None, None) if ggT is None else eigh(ggT)
param_pos = self._A._mapping[name]
if (
not self._A._separate_weight_and_bias
and "weight" in param_pos
and "bias" in param_pos
):
inv_damped_eigenvalues = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues = {}
for p_name, pos in param_pos.items():
if p_name == "weight":
inv_damped_eigenvalues[pos] = (
outer(ggT_eigvals, aaT_eigvals).add_(self._damping).pow_(-1)
)
else:
inv_damped_eigenvalues[pos] = ggT_eigvals.add(
self._damping
).pow_(-1)
inv_damped_eigenvalues = self._compute_inv_damped_eigenvalues(
aaT_eigvals, ggT_eigvals, name
)
return aaT_eigvecs, ggT_eigvecs, inv_damped_eigenvalues
else:
damping_aaT, damping_ggT = self._compute_damping(aaT, ggT)
Expand Down Expand Up @@ -500,7 +518,7 @@ def _compute_inverse_factors(

def _compute_or_get_cached_inverse(
self, name: str
) -> Tuple[KFACType, KFACType, Optional[Tensor]]:
) -> Tuple[FactorType, FactorType, Optional[Tensor]]:
"""Invert the Kronecker factors of the KFACLinearOperator or retrieve them.
Args:
Expand Down
18 changes: 11 additions & 7 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@
# shape as the parameters, or a single matrix/vector of shape `[D, D]`/`[D]` where `D`
# is the number of parameters.
ParameterMatrixType = TypeVar("ParameterMatrixType", Tensor, List[Tensor])
KFACType = TypeVar(
"KFACType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]
FactorType = TypeVar(
"FactorType", Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]
)


Expand Down Expand Up @@ -438,8 +438,8 @@ def _check_input_type_and_preprocess(
@staticmethod
def _left_and_right_multiply(
M_joint: Tensor,
aaT: KFACType,
ggT: KFACType,
aaT: FactorType,
ggT: FactorType,
eigenvalues: Optional[Tensor],
) -> Tensor:
"""Left and right multiply matrix with Kronecker factors.
Expand Down Expand Up @@ -477,8 +477,8 @@ def _left_and_right_multiply(
def _separate_left_and_right_multiply(
M_torch: Tensor,
param_pos: Dict[str, int],
aaT: KFACType,
ggT: KFACType,
aaT: FactorType,
ggT: FactorType,
eigenvalues: Optional[Tensor],
) -> Tensor:
"""Multiply matrix with Kronecker factors for separated weight and bias.
Expand Down Expand Up @@ -933,7 +933,7 @@ def _accumulate_gradient_covariance(
def _compute_eigenvalue_correction(
self, module_name: str, g: Tensor, correction: int
):
"""Compute the corrected eigenvalues for the EKFAC approximation.
r"""Compute the corrected eigenvalues for the EKFAC approximation.
The corrected eigenvalues are computed as
:math:`\lambda_{\text{corrected}} = (Q_g^T G Q_a)^2`, where
Expand Down Expand Up @@ -1083,6 +1083,10 @@ def _set_or_add_(
Returns:
The updated dictionary.
Raises:
ValueError: If the types of the value and the dictionary entry are
incompatible.
"""
if key not in dictionary:
dictionary[key] = value
Expand Down
1 change: 0 additions & 1 deletion test/test_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@

from curvlinops import EFLinearOperator, GGNLinearOperator
from curvlinops.examples.utils import report_nonclose
from curvlinops.gradient_moments import EFLinearOperator
from curvlinops.kfac import FisherType, KFACLinearOperator, KFACType


Expand Down

0 comments on commit a2dec74

Please sign in to comment.