From 58b328330a8b09b2f9f4792e293bbf61c04de684 Mon Sep 17 00:00:00 2001
From: Felix Dangel
Date: Fri, 1 Nov 2024 09:56:54 -0400
Subject: [PATCH] [ADD] Draft exponential moving average for KFAC

---
 curvlinops/kfac.py | 60 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/curvlinops/kfac.py b/curvlinops/kfac.py
index 413c86b..f6513a5 100644
--- a/curvlinops/kfac.py
+++ b/curvlinops/kfac.py
@@ -1233,3 +1233,63 @@ def from_state_dict(
             kfac.to_device(old_device)
 
         return kfac
+
+    def exponential_moving_average_(
+        self, other: KFACLinearOperator, beta: float
+    ) -> None:
+        """Incorporate the Kronecker matrices of another KFAC using EMA.
+
+        Updates the Kronecker factors of the current KFAC linear operator in-place
+        via ``factor = beta * factor + (1 - beta) * other_factor``.
+
+        Args:
+            other: The other KFAC linear operator whose Kronecker factors are
+                incorporated.
+            beta: The EMA decay factor. Must be in ``[0, 1]``. ``1`` means the
+                Kronecker factors are not updated, and ``0`` means the Kronecker factors
+                are replaced by the other KFAC's Kronecker factors.
+
+        Raises:
+            ValueError: If ``beta`` is not in ``[0, 1]``.
+            ValueError: If the Kronecker factors have incompatible data formats.
+
+        """
+        if not 0 <= beta <= 1:
+            raise ValueError("Beta must be in [0, 1].")
+
+        # make sure both KFACs have computed the Kronecker matrices
+        for kfac in (self, other):
+            if not kfac._input_covariances and not kfac._gradient_covariances:
+                kfac._compute_kfac()
+
+        factor_pairs = (
+            ("Input", self._input_covariances, other._input_covariances),
+            ("Gradient", self._gradient_covariances, other._gradient_covariances),
+        )
+
+        # validate every factor before mutating any of them, so a mismatch cannot
+        # leave this operator partially updated
+        for name, factors, other_factors in factor_pairs:
+            keys, other_keys = set(factors.keys()), set(other_factors.keys())
+            if keys != other_keys:
+                raise ValueError(
+                    f"{name} covariance keys do not match: {keys} != {other_keys}."
+                )
+            for key in keys:
+                shape, other_shape = factors[key].shape, other_factors[key].shape
+                if shape != other_shape:
+                    raise ValueError(
+                        f"{name} covariance shapes for {key!r} do not match:"
+                        f" {shape} != {other_shape}."
+                    )
+
+        # apply exponential moving average; both kinds of factors are updated
+        # unconditionally so that an empty dict of one kind cannot suppress the
+        # update of the other kind
+        for _, factors, other_factors in factor_pairs:
+            for key, factor in factors.items():
+                factor.mul_(beta).add_(other_factors[key], alpha=1 - beta)
+
+        # NOTE(review): presumably invalidates quantities cached from the old
+        # factors -- confirm against ``_reset_matrix_properties``
+        self._reset_matrix_properties()