fix(optim): fix AdEMAMix
frgfm committed Sep 9, 2024
1 parent 2db9ed5 commit 1ea5548
Showing 1 changed file with 33 additions and 47 deletions.
80 changes: 33 additions & 47 deletions holocron/optim/ademamix.py
@@ -8,12 +8,12 @@

 import torch
 from torch import Tensor
-from torch.optim import Adam
+from torch.optim import Optimizer

 __all__ = ["AdEMAMix", "ademamix"]


-class AdEMAMix(Adam):
+class AdEMAMix(Optimizer):
     r"""Implements the AdEMAMix optimizer from `"The AdEMAMix Optimizer: Better, Faster, Older" <https://arxiv.org/pdf/2409.03137>`_.
     The estimation of momentums is described as follows, :math:`\forall t \geq 1`:
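Note: the docstring body is collapsed in this view. As a reference sketch read off the revised ademamix helper in the hunks below (weight decay is applied in a collapsed hunk and omitted here), the per-step update is:

\begin{aligned}
m_{1,t} &= \beta_1 m_{1,t-1} + (1 - \beta_1)\, g_t \\
m_{2,t} &= \beta_3 m_{2,t-1} + (1 - \beta_3)\, g_t \\
\nu_t &= \beta_2 \nu_{t-1} + (1 - \beta_2)\,(g_t - m_{1,t})^2 \\
\theta_t &= \theta_{t-1} - \eta\,\frac{m_{1,t} / (1 - \beta_1^t) + \alpha\, m_{2,t}}{\sqrt{\nu_t / (1 - \beta_2^t)} + \epsilon}
\end{aligned}

Here g_t is the gradient, beta1 and beta3 drive the fast and slow EMAs, and the squared term uses the gradient residual g_t - m_{1,t}, exactly as written in the last hunk of this commit.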
@@ -59,10 +59,16 @@ def __init__(
         alpha: float = 5.0,
         eps: float = 1e-8,
         weight_decay: float = 0.0,
-        amsgrad: bool = False,
     ) -> None:
-        super().__init__(params, lr, betas, eps, weight_decay, amsgrad)  # type: ignore[arg-type]
-        self.alpha = alpha
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if eps < 0.0:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        for idx, beta in enumerate(betas):
+            if not 0.0 <= beta < 1.0:
+                raise ValueError(f"Invalid beta parameter at index {idx}: {beta}")
+        defaults = {"lr": lr, "betas": betas, "alpha": alpha, "eps": eps, "weight_decay": weight_decay}
+        super().__init__(params, defaults)

     @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
@@ -79,10 +85,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         for group in self.param_groups:
             params_with_grad = []
             grads = []
-            m1 = []
-            m2 = []
-            nu = []
-            max_nu = []
+            exp_avgs = []
+            exp_avgs_slow = []
+            exp_avg_sqs = []
             state_steps = []

             for p in group["params"]:
@@ -97,20 +102,14 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                     if len(state) == 0:
                         state["step"] = 0
                         # Exponential moving average of gradient values
-                        state["m1"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                        state["m2"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        state["exp_avg_slow"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                         # Exponential moving average of squared gradient values
-                        state["nu"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                        if group["amsgrad"]:
-                            # Maintains max of all exp. moving avg. of sq. grad. values
-                            state["max_nu"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

-                    m1.append(state["m1"])
-                    m2.append(state["m2"])
-                    nu.append(state["nu"])
-
-                    if group["amsgrad"]:
-                        max_nu.append(state["max_nu"])
+                    exp_avgs.append(state["exp_avg"])
+                    exp_avgs_slow.append(state["exp_avg_slow"])
+                    exp_avg_sqs.append(state["exp_avg_sq"])

                     # update the steps for each param group update
                     state["step"] += 1
@@ -121,16 +120,14 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
             ademamix(
                 params_with_grad,
                 grads,
-                m1,
-                m2,
-                nu,
-                max_nu,
+                exp_avgs,
+                exp_avgs_slow,
+                exp_avg_sqs,
                 state_steps,
-                group["amsgrad"],
                 beta1,
                 beta2,
                 beta3,
-                self.alpha,
+                group["alpha"],
                 group["lr"],
                 group["weight_decay"],
                 group["eps"],
@@ -141,12 +138,10 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
 def ademamix(
     params: List[Tensor],
     grads: List[Tensor],
-    m1s: List[Tensor],
-    m2s: List[Tensor],
-    nus: List[Tensor],
-    max_nus: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avgs_slow: List[Tensor],
+    exp_avg_sqs: List[Tensor],
     state_steps: List[int],
-    amsgrad: bool,
     beta1: float,
     beta2: float,
     beta3: float,
@@ -160,12 +155,10 @@ def ademamix(
     """
     for i, param in enumerate(params):
         grad = grads[i]
-        m1 = m1s[i]
-        m2 = m2s[i]
-        nu = nus[i]
+        m1 = exp_avgs[i]
+        m2 = exp_avgs_slow[i]
+        nu = exp_avg_sqs[i]
         step = state_steps[i]
-        if amsgrad:
-            max_nu = max_nus[i]

         bias_correction1 = 1 - beta1**step
         bias_correction2 = 1 - beta2**step
@@ -175,16 +168,9 @@ def ademamix(
         # Decay the first and second moment running average coefficient
         m1.mul_(beta1).add_(grad, alpha=1 - beta1)
-        nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
         m2.mul_(beta3).add_(grad, alpha=1 - beta3)
+        grad_residual = grad - m1
+        nu.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)

-        if amsgrad:
-            # Maintains the maximum of all 2nd moment running avg. till now
-            torch.maximum(max_nu, m2, out=max_nu)
-            # Use the max. for normalizing running avg. of gradient
-            denom = (max_nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)
-        else:
-            denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)
-
+        denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)

         param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr)
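For context, a minimal usage sketch of the optimizer after this change. The import path simply mirrors the edited file's location (holocron/optim/ademamix.py); the model, data, and hyper-parameter values are illustrative placeholders, and the collapsed defaults (betas, lr) are left untouched except for lr.

import torch
from torch import nn

# Module path mirrors the edited file; a re-export from holocron.optim may also exist.
from holocron.optim.ademamix import AdEMAMix

model = nn.Linear(10, 1)  # toy model, purely illustrative
criterion = nn.MSELoss()
# alpha/eps/weight_decay defaults are visible in the diff; betas keeps its (collapsed) default.
optimizer = AdEMAMix(model.parameters(), lr=1e-3)

x, y = torch.randn(32, 10), torch.randn(32, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()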
