diff --git a/holocron/optim/ademamix.py b/holocron/optim/ademamix.py
index 0bfb17f0..422c05a0 100644
--- a/holocron/optim/ademamix.py
+++ b/holocron/optim/ademamix.py
@@ -8,12 +8,12 @@
 import torch
 from torch import Tensor
-from torch.optim import Adam
+from torch.optim import Optimizer
 
 __all__ = ["AdEMAMix", "ademamix"]
 
 
-class AdEMAMix(Adam):
+class AdEMAMix(Optimizer):
     r"""Implements the AdEMAMix optimizer from
     `"The AdEMAMix Optimizer: Better, Faster, Older" `_.
 
     The estimation of momentums is described as follows, :math:`\forall t \geq 1`:
@@ -59,10 +59,16 @@ def __init__(
         alpha: float = 5.0,
         eps: float = 1e-8,
         weight_decay: float = 0.0,
-        amsgrad: bool = False,
     ) -> None:
-        super().__init__(params, lr, betas, eps, weight_decay, amsgrad)  # type: ignore[arg-type]
-        self.alpha = alpha
+        if lr < 0.0:
+            raise ValueError(f"Invalid learning rate: {lr}")
+        if eps < 0.0:
+            raise ValueError(f"Invalid epsilon value: {eps}")
+        for idx, beta in enumerate(betas):
+            if not 0.0 <= beta < 1.0:
+                raise ValueError(f"Invalid beta parameter at index {idx}: {beta}")
+        defaults = {"lr": lr, "betas": betas, "alpha": alpha, "eps": eps, "weight_decay": weight_decay}
+        super().__init__(params, defaults)
 
     @torch.no_grad()
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:  # type: ignore[override]
@@ -79,10 +85,9 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
         for group in self.param_groups:
             params_with_grad = []
             grads = []
-            m1 = []
-            m2 = []
-            nu = []
-            max_nu = []
+            exp_avgs = []
+            exp_avgs_slow = []
+            exp_avg_sqs = []
             state_steps = []
 
             for p in group["params"]:
@@ -97,20 +102,14 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
                 if len(state) == 0:
                     state["step"] = 0
                     # Exponential moving average of gradient values
-                    state["m1"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    state["m2"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg_slow"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                     # Exponential moving average of squared gradient values
-                    state["nu"] = torch.zeros_like(p, memory_format=torch.preserve_format)
-                    if group["amsgrad"]:
-                        # Maintains max of all exp. moving avg. of sq. grad. values
-                        state["max_nu"] = torch.zeros_like(p, memory_format=torch.preserve_format)
+                    state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
 
-                m1.append(state["m1"])
-                m2.append(state["m2"])
-                nu.append(state["nu"])
-
-                if group["amsgrad"]:
-                    max_nu.append(state["max_nu"])
+                exp_avgs.append(state["exp_avg"])
+                exp_avgs_slow.append(state["exp_avg_slow"])
+                exp_avg_sqs.append(state["exp_avg_sq"])
 
                 # update the steps for each param group update
                 state["step"] += 1
@@ -121,16 +120,14 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
             ademamix(
                 params_with_grad,
                 grads,
-                m1,
-                m2,
-                nu,
-                max_nu,
+                exp_avgs,
+                exp_avgs_slow,
+                exp_avg_sqs,
                 state_steps,
-                group["amsgrad"],
                 beta1,
                 beta2,
                 beta3,
-                self.alpha,
+                group["alpha"],
                 group["lr"],
                 group["weight_decay"],
                 group["eps"],
@@ -141,12 +138,10 @@ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]
 def ademamix(
     params: List[Tensor],
     grads: List[Tensor],
-    m1s: List[Tensor],
-    m2s: List[Tensor],
-    nus: List[Tensor],
-    max_nus: List[Tensor],
+    exp_avgs: List[Tensor],
+    exp_avgs_slow: List[Tensor],
+    exp_avg_sqs: List[Tensor],
     state_steps: List[int],
-    amsgrad: bool,
     beta1: float,
     beta2: float,
     beta3: float,
@@ -160,12 +155,10 @@ def ademamix(
     """
     for i, param in enumerate(params):
         grad = grads[i]
-        m1 = m1s[i]
-        m2 = m2s[i]
-        nu = nus[i]
+        m1 = exp_avgs[i]
+        m2 = exp_avgs_slow[i]
+        nu = exp_avg_sqs[i]
         step = state_steps[i]
-        if amsgrad:
-            max_nu = max_nus[i]
 
         bias_correction1 = 1 - beta1**step
         bias_correction2 = 1 - beta2**step
@@ -175,16 +168,9 @@ def ademamix(
 
         # Decay the first and second moment running average coefficient
         m1.mul_(beta1).add_(grad, alpha=1 - beta1)
+        nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
         m2.mul_(beta3).add_(grad, alpha=1 - beta3)
-        grad_residual = grad - m1
-        nu.mul_(beta2).addcmul_(grad_residual, grad_residual, value=1 - beta2)
-
-        if amsgrad:
-            # Maintains the maximum of all 2nd moment running avg. till now
-            torch.maximum(max_nu, m2, out=max_nu)
-            # Use the max. for normalizing running avg. of gradient
-            denom = (max_nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)
-        else:
-            denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+        denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)
 
         param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr)