Skip to content

Commit

Permalink
Address KFAC refactor suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
runame committed Sep 21, 2024
1 parent 89c814f commit cf45ef5
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ def __init__(
"Only mc_samples=1 is supported for `fisher_type != FisherType.MC`."
)
if fisher_type == FisherType.FORWARD_ONLY and correct_eigenvalues:
- raise ValueError(
+ raise NotImplementedError(
"Correcting eigenvalues is not supported for FisherType.FORWARD_ONLY."
)
if kfac_approx not in self._SUPPORTED_KFAC_APPROX:
Expand All @@ -277,6 +277,8 @@ def __init__(
self._mc_samples = mc_samples
self._kfac_approx = kfac_approx
self._correct_eigenvalues = correct_eigenvalues
+ # Initialize flag which determines whether to compute the KFAC factors or the
+ # eigenvalue correction in the forward-backward pass(es)
self._compute_eigenvalue_correction_flag = False
self._input_covariances: Dict[str, Tensor] = {}
self._gradient_covariances: Dict[str, Tensor] = {}
Expand Down Expand Up @@ -457,7 +459,9 @@ def _left_and_right_multiply(
aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for
biases.
ggT: Gradient covariance Kronecker factor or its eigenvectors.
- eigenvalues: Corrected eigenvalues for the EKFAC approximation.
+ eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with
+ the eigendecomposition of the KFAC approximation. ``None`` for the
+ non-decomposed KFAC approximation.
Returns:
Matrix-multiplication result ``KFAC @ M_joint``.
Expand All @@ -473,7 +477,7 @@ def _left_and_right_multiply(
M_joint = einsum(
ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l"
)
- # Multiply by eigenvalues.
+ # Multiply (broadcasted) by eigenvalues.
M_joint.mul_(eigenvalues)
# Transform back to standard basis.
M_joint = einsum(
Expand All @@ -497,7 +501,9 @@ def _separate_left_and_right_multiply(
aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for
biases.
ggT: Gradient covariance Kronecker factor or its eigenvectors.
- eigenvalues: Corrected eigenvalues for the EKFAC approximation.
+ eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with
+ the eigendecomposition of the KFAC approximation. ``None`` for the
+ non-decomposed KFAC approximation.
Returns:
Matrix-multiplication result ``KFAC @ M_torch``.
Expand All @@ -510,12 +516,13 @@ def _separate_left_and_right_multiply(
# If `eigenvalues` is not `None`, we transform to eigenbasis here
M_torch[pos] = einsum(M_w, aaT, "m i j, j k -> m i k")

- dims = "m j ... -> m i ..." if eigenvalues is None else "m i ... -> m j ..."
- # If `eigenvalues` is not `None`, we transform to eigenbasis here
- M_torch[pos] = einsum(ggT, M_torch[pos], f"i j, {dims}")
+ # If `eigenvalues` is not `None`, we convert to eigenbasis here
+ M_torch[pos] = einsum(
+ ggT.T if eigenvalues else ggT, M_torch[pos], "i j, m j ... -> m i ..."
+ )

if eigenvalues is not None:
- # Multiply by eigenvalues and transform back to standard basis
+ # Multiply (broadcasted) by eigenvalues, convert back to original basis
M_torch[pos].mul_(eigenvalues[pos])
if p_name == "weight":
M_torch[pos] = einsum(M_torch[pos], aaT, "m i j, k j -> m i k")
Expand Down Expand Up @@ -1101,7 +1108,10 @@ def _set_or_add_(
elif isinstance(dictionary[key], Tensor) and isinstance(value, Tensor):
dictionary[key].add_(value)
else:
- raise ValueError("Incompatible types for addition.")
+ raise ValueError(
+ "Incompatible types for addition: dictionary value of type "
+ f"{type(dictionary[key])} and value to be added of type {type(value)}."
+ )
return dictionary

@classmethod
Expand Down

0 comments on commit cf45ef5

Please sign in to comment.