Skip to content

Commit

Permalink
[ADD] Draft exponential moving average for KFAC
Browse files Browse the repository at this point in the history
  • Loading branch information
f-dangel committed Nov 1, 2024
1 parent 2b7a745 commit 58b3283
Showing 1 changed file with 72 additions and 0 deletions.
72 changes: 72 additions & 0 deletions curvlinops/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -1233,3 +1233,75 @@ def from_state_dict(
kfac.to_device(old_device)

return kfac

def exponential_moving_average_(
self, other: KFACLinearOperator, beta: float
) -> None:
"""Incorporate the Kronecker matrices of another KFAC using EMA.
Updates the Kronecker factors of the current KFAC linear operator in-place.
Args:
other: The other KFAC linear operator whose Kronecker factors are
incorporated.
beta: The EMA decay factor. Must be in ``[0, 1]``. ``1`` means the
Kronecker factors are not updated, and ``0`` means the Kronecker factors
are replaced by the other KFAC's Kronecker factors.
Raises:
ValueError: If ``beta`` is not in ``[0, 1]``.
ValueError: If the Kronecker factors have incompatible data formats.
"""
if not 0 <= beta <= 1:
raise ValueError("Beta must be in [0, 1].")

# make sure both KFACs have computed the Kronecker matrices
if not self._input_covariances and not self._gradient_covariances:
self._compute_kfac()
if not other._input_covariances and not other._gradient_covariances:
other._compute_kfac()

# make sure the Kronecker matrices have the same keys and shapes
keys = set(self._input_covariances.keys())
other_keys = set(other._input_covariances.keys())
if keys != other_keys:
raise ValueError(
f"Input covariance keys do not match: {keys} != {other_keys}."
)
for key in keys:
aaT = self._input_covariances[key]
aaT_other = other._input_covariances[key]
if aaT.shape != aaT_other.shape:
raise ValueError(
f"Input covariance shapes for {key!r} do not match:"
f" {aaT.shape} != {aaT_other.shape}."
)

keys = set(self._gradient_covariances.keys())
other_keys = set(other._gradient_covariances.keys())
if keys != other_keys:
raise ValueError(
f"Gradient covariance keys do not match: {keys} != {other_keys}."
)
for key in keys:
ggT = self._gradient_covariances[key]
ggT_other = other._gradient_covariances[key]
if ggT.shape != ggT_other.shape:
raise ValueError(
f"Gradient covariance shapes for {key!r} do not match:"
f" {ggT.shape} != {ggT_other.shape}."
)

# apply exponential moving average
if self._input_covariances and self._gradient_covariances:
for key in self._input_covariances.keys():
self._input_covariances[key].mul_(beta).add_(
other._input_covariances[key], alpha=1 - beta
)
for key in self._gradient_covariances.keys():
self._gradient_covariances[key].mul_(beta).add_(
other._gradient_covariances[key], alpha=1 - beta
)

self._reset_matrix_properties()

0 comments on commit 58b3283

Please sign in to comment.