From cf45ef5676247ad8af5bffb4d0eb55d05db638a3 Mon Sep 17 00:00:00 2001 From: runame Date: Sat, 21 Sep 2024 14:52:21 -0400 Subject: [PATCH] Address KFAC refactor suggestions --- curvlinops/kfac.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py index a747b10..d41ee0e 100644 --- a/curvlinops/kfac.py +++ b/curvlinops/kfac.py @@ -261,7 +261,7 @@ def __init__( "Only mc_samples=1 is supported for `fisher_type != FisherType.MC`." ) if fisher_type == FisherType.FORWARD_ONLY and correct_eigenvalues: - raise ValueError( + raise NotImplementedError( "Correcting eigenvalues is not supported for FisherType.FORWARD_ONLY." ) if kfac_approx not in self._SUPPORTED_KFAC_APPROX: @@ -277,6 +277,8 @@ def __init__( self._mc_samples = mc_samples self._kfac_approx = kfac_approx self._correct_eigenvalues = correct_eigenvalues + # Initialize flag which determines whether to compute the KFAC factors or the + # eigenvalue correction in the forward-backward pass(es) self._compute_eigenvalue_correction_flag = False self._input_covariances: Dict[str, Tensor] = {} self._gradient_covariances: Dict[str, Tensor] = {} @@ -457,7 +459,9 @@ def _left_and_right_multiply( aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for biases. ggT: Gradient covariance Kronecker factor or its eigenvectors. - eigenvalues: Corrected eigenvalues for the EKFAC approximation. + eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with + the eigendecomposition of the KFAC approximation. ``None`` for the + non-decomposed KFAC approximation. Returns: Matrix-multiplication result ``KFAC @ M_joint``. @@ -473,7 +477,7 @@ def _left_and_right_multiply( M_joint = einsum( ggT_eigvecs, M_joint, aaT_eigvecs, "i j, m i k, k l -> m j l" ) - # Multiply by eigenvalues. + # Multiply (broadcasted) by eigenvalues. M_joint.mul_(eigenvalues) # Transform back to standard basis. M_joint = einsum( @@ -497,7 +501,9 @@ def _separate_left_and_right_multiply( aaT: Input covariance Kronecker factor or its eigenvectors. ``None`` for biases. ggT: Gradient covariance Kronecker factor or its eigenvectors. - eigenvalues: Corrected eigenvalues for the EKFAC approximation. + eigenvalues: Eigenvalues of the (E)KFAC approximation when multiplying with + the eigendecomposition of the KFAC approximation. ``None`` for the + non-decomposed KFAC approximation. Returns: Matrix-multiplication result ``KFAC @ M_torch``. @@ -510,12 +516,13 @@ def _separate_left_and_right_multiply( # If `eigenvalues` is not `None`, we transform to eigenbasis here M_torch[pos] = einsum(M_w, aaT, "m i j, j k -> m i k") - dims = "m j ... -> m i ..." if eigenvalues is None else "m i ... -> m j ..." - # If `eigenvalues` is not `None`, we transform to eigenbasis here - M_torch[pos] = einsum(ggT, M_torch[pos], f"i j, {dims}") + # If `eigenvalues` is not `None`, we convert to eigenbasis here + M_torch[pos] = einsum( + ggT.T if eigenvalues else ggT, M_torch[pos], "i j, m j ... -> m i ..." + ) if eigenvalues is not None: - # Multiply by eigenvalues and transform back to standard basis + # Multiply (broadcasted) by eigenvalues, convert back to original basis M_torch[pos].mul_(eigenvalues[pos]) if p_name == "weight": M_torch[pos] = einsum(M_torch[pos], aaT, "m i j, k j -> m i k") @@ -1101,7 +1108,10 @@ def _set_or_add_( elif isinstance(dictionary[key], Tensor) and isinstance(value, Tensor): dictionary[key].add_(value) else: - raise ValueError("Incompatible types for addition.") + raise ValueError( + "Incompatible types for addition: dictionary value of type " + f"{type(dictionary[key])} and value to be added of type {type(value)}." + ) return dictionary @classmethod