feat(optim): add support of AdEMAMix optimizer #373

Merged · 10 commits · Sep 9, 2024
2 changes: 1 addition & 1 deletion .github/workflows/builds.yml
@@ -97,7 +97,7 @@ jobs:
poetry lock
poetry export -f requirements.txt --without-hashes --output requirements.txt
- name: Build & run docker
run: cd api && docker-compose up -d --build
run: cd api && docker compose up -d --build
- name: Docker sanity check
run: sleep 15 && nc -vz localhost 8050
- name: Ping server
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
@@ -74,7 +74,7 @@ jobs:
poetry lock
poetry export -f requirements.txt --without-hashes --with dev --output requirements.txt
- name: Build & run docker
run: cd api && docker-compose up -d --build
run: cd api && docker compose up -d --build
- name: Docker sanity check
run: sleep 15 && nc -vz localhost 8050
- name: Ping server
2 changes: 1 addition & 1 deletion README.md
@@ -132,7 +132,7 @@ pip install -e Holocron/.
- boxes: [Distance-IoU & Complete-IoU losses](https://arxiv.org/abs/1911.08287)

### Trying something else than Adam
- Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [TAdam](https://arxiv.org/abs/2003.00179), [AdamP](https://arxiv.org/abs/2006.08217), [AdaBelief](https://arxiv.org/abs/2010.07468), [Adan](https://arxiv.org/abs/2208.06677), and customized versions (RaLars)
- Optimizer: [LARS](https://arxiv.org/abs/1708.03888), [Lamb](https://arxiv.org/abs/1904.00962), [TAdam](https://arxiv.org/abs/2003.00179), [AdamP](https://arxiv.org/abs/2006.08217), [AdaBelief](https://arxiv.org/abs/2010.07468), [Adan](https://arxiv.org/abs/2208.06677), and customized versions (RaLars), [AdEMAMix](https://arxiv.org/abs/2409.03137)
- Optimizer wrapper: [Lookahead](https://arxiv.org/abs/1907.08610), Scout (experimental)


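For context (not part of this PR's diff): a minimal, hedged sketch of what driving a training loop with the newly listed AdEMAMix optimizer could look like. The model and data are placeholders, and the constructor arguments simply mirror the defaults defined in `holocron/optim/ademamix.py` further down.

```python
import torch
from torch import nn

from holocron.optim import AdEMAMix

# Placeholder model and data; any torch.nn.Module can be optimized the same way
model = nn.Linear(4, 1)
optimizer = AdEMAMix(model.parameters(), lr=1e-3, betas=(0.9, 0.999, 0.9999), alpha=5.0)

x, y = torch.randn(8, 4), torch.randn(8, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()  # applies the mixed-EMA update implemented in this PR
```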
2 changes: 1 addition & 1 deletion api/tests/routes/test_classification.py
@@ -1,7 +1,7 @@
import pytest


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_classification(test_app_asyncio, mock_classification_image):
response = await test_app_asyncio.post("/classification", files={"file": mock_classification_image})
assert response.status_code == 200
2 changes: 2 additions & 0 deletions docs/source/optim.rst
@@ -27,6 +27,8 @@ Implementations of recent parameter optimizer for Pytorch modules.

.. autoclass:: Adan

.. autoclass:: AdEMAMix


Optimizer wrappers
------------------
1 change: 1 addition & 0 deletions holocron/optim/__init__.py
@@ -2,6 +2,7 @@
from .adabelief import AdaBelief
from .adamp import AdamP
from .adan import Adan
from .ademamix import AdEMAMix

Check notice (Codacy Static Code Analysis) on holocron/optim/__init__.py#L5: '.ademamix.AdEMAMix' imported but unused (F401)
from .lamb import LAMB
from .lars import LARS
from .ralars import RaLars
2 changes: 1 addition & 1 deletion holocron/optim/adabelief.py
@@ -24,7 +24,7 @@ class AdaBelief(Adam):
s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) (g_t - m_t)^2 + \epsilon

where :math:`g_t` is the gradient of :math:`\theta_t`,
:math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
:math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = 0,\ s_0 = 0`, :math:`\epsilon > 0`.

Then we correct their biases using:
2 changes: 1 addition & 1 deletion holocron/optim/adamp.py
@@ -25,7 +25,7 @@ class AdamP(Adam):
v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2

where :math:`g_t` is the gradient of :math:`\theta_t`,
:math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
:math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = g_0,\ v_0 = 0`.

Then we correct their biases using:
176 changes: 176 additions & 0 deletions holocron/optim/ademamix.py
@@ -0,0 +1,176 @@
# Copyright (C) 2024, François-Guillaume Fernandez.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://www.apache.org/licenses/LICENSE-2.0> for full license details.

import math
from typing import Callable, Iterable, List, Optional, Tuple

import torch
from torch import Tensor
from torch.optim import Optimizer

__all__ = ["AdEMAMix", "ademamix"]


class AdEMAMix(Optimizer):
r"""Implements the AdEMAMix optimizer from `"The AdEMAMix Optimizer: Better, Faster, Older" <https://arxiv.org/pdf/2409.03137>`_.

The moment estimates are updated as follows, :math:`\forall t \geq 1`:

.. math::
m_{1,t} \leftarrow \beta_1 m_{1, t-1} + (1 - \beta_1) g_t \\
m_{2,t} \leftarrow \beta_3 m_{2, t-1} + (1 - \beta_3) g_t \\
s_t \leftarrow \beta_2 s_{t-1} + (1 - \beta_2) g_t^2

where :math:`g_t` is the gradient of :math:`\theta_t`,
:math:`\beta_1, \beta_2, \beta_3 \in [0, 1]^3` are the exponential average smoothing coefficients,
:math:`m_{1,0} = 0,\ m_{2,0} = 0,\ s_0 = 0`, :math:`\epsilon > 0`.

Then we correct their biases using:

.. math::
\hat{m_{1,t}} \leftarrow \frac{m_{1,t}}{1 - \beta_1^t} \\
\hat{s_t} \leftarrow \frac{s_t}{1 - \beta_2^t}

And finally the update step is performed using the following rule:

.. math::
\theta_t \leftarrow \theta_{t-1} - \eta \frac{\hat{m_{1,t}} + \alpha m_{2,t}}{\sqrt{\hat{s_t}} + \epsilon}

where :math:`\theta_t` is the parameter value at step :math:`t` (:math:`\theta_0` being the initialization value),
:math:`\eta` is the learning rate, :math:`\alpha > 0` and :math:`\epsilon > 0`.

Args:
params (iterable): iterable of parameters to optimize or dicts defining parameter groups
lr (float, optional): learning rate
betas (Tuple[float, float, float], optional): coefficients used for the fast, squared-gradient and slow running averages (default: (0.9, 0.999, 0.9999))
alpha (float, optional): coefficient used to mix the slow EMA :math:`m_{2,t}` into the update (default: 5.0)
eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
"""

def __init__(
self,
params: Iterable[torch.nn.Parameter],
lr: float = 1e-3,
betas: Tuple[float, float, float] = (0.9, 0.999, 0.9999),
alpha: float = 5.0,
eps: float = 1e-8,
weight_decay: float = 0.0,
) -> None:
if lr < 0.0:
raise ValueError(f"Invalid learning rate: {lr}")
if eps < 0.0:
raise ValueError(f"Invalid epsilon value: {eps}")
for idx, beta in enumerate(betas):
if not 0.0 <= beta < 1.0:
raise ValueError(f"Invalid beta parameter at index {idx}: {beta}")
defaults = {"lr": lr, "betas": betas, "alpha": alpha, "eps": eps, "weight_decay": weight_decay}
super().__init__(params, defaults)

@torch.no_grad()
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: # type: ignore[override]
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avgs_slow = []
exp_avg_sqs = []
state_steps = []

for p in group["params"]:
if p.grad is not None:
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(f"{self.__class__.__name__} does not support sparse gradients")
grads.append(p.grad)

state = self.state[p]
# Lazy state initialization
if len(state) == 0:
state["step"] = 0
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
state["exp_avg_slow"] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)

exp_avgs.append(state["exp_avg"])
exp_avgs_slow.append(state["exp_avg_slow"])
exp_avg_sqs.append(state["exp_avg_sq"])

# update the steps for each param group update
state["step"] += 1
# record the step after step update
state_steps.append(state["step"])

beta1, beta2, beta3 = group["betas"]
ademamix(
params_with_grad,
grads,
exp_avgs,
exp_avgs_slow,
exp_avg_sqs,
state_steps,
beta1,
beta2,
beta3,
group["alpha"],
group["lr"],
group["weight_decay"],
group["eps"],
)
return loss


def ademamix(
params: List[Tensor],
grads: List[Tensor],
exp_avgs: List[Tensor],
exp_avgs_slow: List[Tensor],
exp_avg_sqs: List[Tensor],
state_steps: List[int],
beta1: float,
beta2: float,
beta3: float,
alpha: float,
lr: float,
weight_decay: float,
eps: float,
) -> None:
r"""Functional API that performs AdaBelief algorithm computation.
See :class:`~holocron.optim.AdaBelief` for details.
"""
for i, param in enumerate(params):
grad = grads[i]
m1 = exp_avgs[i]
m2 = exp_avgs_slow[i]
nu = exp_avg_sqs[i]
step = state_steps[i]

bias_correction1 = 1 - beta1**step
bias_correction2 = 1 - beta2**step

if weight_decay != 0:
grad = grad.add(param, alpha=weight_decay)

# Decay the first and second moment running average coefficient
m1.mul_(beta1).add_(grad, alpha=1 - beta1)
nu.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
m2.mul_(beta3).add_(grad, alpha=1 - beta3)

denom = (nu.sqrt() / math.sqrt(bias_correction2)).add_(eps)

param.addcdiv_(m1 / bias_correction1 + alpha * m2, denom, value=-lr)
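
As a companion to the docstring above, here is a hedged sanity-check sketch that replays the first optimization step by hand and compares it to the stated formulas (at t=1 the bias corrections cancel); the tensor values are arbitrary.

```python
import torch

from holocron.optim import AdEMAMix

lr, (b1, b2, b3), alpha, eps = 1e-3, (0.9, 0.999, 0.9999), 5.0, 1e-8
p = torch.nn.Parameter(torch.tensor([1.0, -2.0]))
opt = AdEMAMix([p], lr=lr, betas=(b1, b2, b3), alpha=alpha, eps=eps)

p.grad = torch.tensor([0.5, -0.1])
g = p.grad.clone()
before = p.detach().clone()
opt.step()

# First-step quantities from the docstring: at t=1 the bias corrections reduce
# m1_hat to g and s_hat to g**2, while the slow EMA m2 equals (1 - b3) * g
m1_hat = g
m2 = (1 - b3) * g
s_hat = g**2
expected = before - lr * (m1_hat + alpha * m2) / (s_hat.sqrt() + eps)
assert torch.allclose(p.detach(), expected, atol=1e-6)
```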
2 changes: 1 addition & 1 deletion holocron/optim/lamb.py
@@ -22,7 +22,7 @@ class LAMB(Optimizer):
v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) g_t^2

where :math:`g_t` is the gradient of :math:`\theta_t`,
:math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
:math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = 0,\ v_0 = 0`.

Then we correct their biases using:
2 changes: 1 addition & 1 deletion holocron/optim/tadam.py
@@ -26,7 +26,7 @@ class TAdam(Optimizer):
v_t \leftarrow \beta_2 v_{t-1} + (1 - \beta_2) (g_t - g_{t-1})

where :math:`g_t` is the gradient of :math:`\theta_t`,
:math:`\beta_1, \beta_2 \in [0, 1]^3` are the exponential average smoothing coefficients,
:math:`\beta_1, \beta_2 \in [0, 1]^2` are the exponential average smoothing coefficients,
:math:`m_0 = 0,\ v_0 = 0,\ W_0 = \frac{\beta_1}{1 - \beta_1}`;
:math:`\nu` is the degrees of freedom and :math:`d` is the number of dimensions of the parameter gradient.

6 changes: 5 additions & 1 deletion references/classification/train.py
@@ -30,7 +30,7 @@
from holocron.models import classification
from holocron.models.presets import CIFAR10 as CIF10
from holocron.models.presets import IMAGENETTE
from holocron.optim import AdaBelief, AdamP
from holocron.optim import AdaBelief, AdamP, AdEMAMix
from holocron.trainer import ClassificationTrainer
from holocron.utils.data import Mixup
from holocron.utils.misc import find_image_size
@@ -208,6 +208,10 @@ def main(args):
optimizer = AdamP(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
elif args.opt == "adabelief":
optimizer = AdaBelief(model_params, args.lr, betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay)
elif args.opt == "ademamix":
optimizer = AdEMAMix(
model_params, args.lr, betas=(0.95, 0.99, 0.9999), eps=1e-6, weight_decay=args.weight_decay
)

log_wb = lambda metrics: wandb.log(metrics) if args.wb else None
trainer = ClassificationTrainer(
2 changes: 1 addition & 1 deletion tests/test_models_classification.py
@@ -42,7 +42,7 @@ def test_repvgg_reparametrize():
assert mod.weight.data.shape[2:] == (3, 3)
# Check that values are still matching
with torch.no_grad():
assert torch.allclose(out, model(x), atol=1e-4)
assert torch.allclose(out, model(x), atol=1e-3)


def test_mobileone_reparametrize():
2 changes: 1 addition & 1 deletion tests/test_ops.py
@@ -6,7 +6,7 @@
from holocron import ops


@pytest.fixture()
@pytest.fixture
def boxes():
return torch.tensor(
[[0, 0, 100, 100], [50, 50, 100, 100], [50, 50, 150, 150], [100, 100, 200, 200]], dtype=torch.float32
4 changes: 4 additions & 0 deletions tests/test_optim.py
@@ -65,3 +65,7 @@ def test_adamp():

def test_adan():
_test_optimizer("Adan")


def test_ademamix():
_test_optimizer("AdEMAMix")
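
The shared `_test_optimizer` helper is defined earlier in `tests/test_optim.py` and is not shown in this diff; purely as an assumption, it presumably runs a smoke test along these lines (the `_smoke_test` name below is hypothetical).

```python
import torch
from torch import nn

from holocron import optim


def _smoke_test(optimizer_name: str) -> None:
    # Hypothetical stand-in for the repository's _test_optimizer helper
    model = nn.Linear(3, 1)
    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=1e-3)
    weight_before = model.weight.detach().clone()
    loss = model(torch.randn(4, 3)).abs().sum()
    loss.backward()
    optimizer.step()
    # A single update should move the weights away from their initial values
    assert not torch.equal(weight_before, model.weight)


_smoke_test("AdEMAMix")
```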