From 6ae2426d6bcb6174e44ce7cbd67496faec363f1a Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Sat, 2 Nov 2024 17:35:42 -0400 Subject: [PATCH] Add Nesterov momentum to AdaBelief optimizer. --- optax/_src/alias.py | 21 +++++++++++++++++++-- optax/_src/transform.py | 14 ++++++++++++-- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/optax/_src/alias.py b/optax/_src/alias.py index b3a6807f..3c85eef4 100644 --- a/optax/_src/alias.py +++ b/optax/_src/alias.py @@ -36,7 +36,10 @@ def adabelief( b1: float = 0.9, b2: float = 0.999, eps: float = 1e-16, - eps_root: float = 1e-16) -> base.GradientTransformation: + eps_root: float = 1e-16, + *, + nesterov: bool = False, +) -> base.GradientTransformation: r"""The AdaBelief optimizer. AdaBelief is an adaptive learning rate optimizer that focuses on fast @@ -74,6 +77,13 @@ def adabelief( S_t &\leftarrow (m_t, s_t). \end{align*} + With the keyword argument `nesterov=True`, the optimizer uses Nesterov + momentum, replacing the above :math:`\hat{m}_t` with + + .. math:: + \hat{m}_t \leftarrow + \beta_1 m_t / {(1-\beta_1^{t+1})} + (1 - \beta_1) g_t / {(1-\beta_1^t)}. + Examples: >>> import optax >>> import jax @@ -107,12 +117,19 @@ def adabelief( eps_root: Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero. + nesterov: Whether to use Nesterov momentum. Returns: The corresponding `GradientTransformation`. """ return combine.chain( - transform.scale_by_belief(b1=b1, b2=b2, eps=eps, eps_root=eps_root), + transform.scale_by_belief( + b1=b1, + b2=b2, + eps=eps, + eps_root=eps_root, + nesterov=nesterov, + ), transform.scale_by_learning_rate(learning_rate), ) diff --git a/optax/_src/transform.py b/optax/_src/transform.py index 394306b0..ab629fbf 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -685,7 +685,9 @@ def scale_by_belief( b1: float = 0.9, b2: float = 0.999, eps: float = 1e-16, - eps_root: float = 1e-16 + eps_root: float = 1e-16, + *, + nesterov: bool = False, ) -> base.GradientTransformation: """Rescale updates according to the AdaBelief algorithm. @@ -699,6 +701,7 @@ def scale_by_belief( eps_root: Term added to the second moment of the prediction error to improve numerical stability. If backpropagating gradients through the gradient transformation (e.g. for meta-learning), this must be non-zero. + nesterov: Whether to use Nesterov momentum. Returns: A `GradientTransformation` object. @@ -717,7 +720,14 @@ def update_fn(updates, state, params=None): nu = otu.tree_update_moment_per_elem_norm(prediction_error, state.nu, b2, 2) nu = jax.tree.map(lambda v: v + eps_root, nu) count_inc = numerics.safe_increment(state.count) - mu_hat = otu.tree_bias_correction(mu, b1, count_inc) + if nesterov: + mu_hat = jax.tree.map( + lambda m, g: b1 * m + (1 - b1) * g, + otu.tree_bias_correction( + mu, b1, numerics.safe_increment(count_inc)), + otu.tree_bias_correction(updates, b1, count_inc)) + else: + mu_hat = otu.tree_bias_correction(mu, b1, count_inc) nu_hat = otu.tree_bias_correction(nu, b2, count_inc) updates = jax.tree.map( lambda m, v: None if m is None else m / (jnp.sqrt(v) + eps),
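A minimal usage sketch, assuming this patch is applied so that `optax.adabelief` accepts the keyword-only `nesterov` argument: the scalar helper `nesterov_mu_hat`, the toy quadratic objective, and the chosen learning rate are illustrative only and are not taken from the optax test suite. The `if nesterov:` branch above combines two differently bias-corrected terms, matching the docstring formula: the running mean `m_t` is divided by `1 - b1**(t + 1)` (hence the extra `numerics.safe_increment(count_inc)`), while the raw gradient `g_t` is divided by `1 - b1**t`.

    import jax
    import jax.numpy as jnp
    import optax


    def nesterov_mu_hat(m_t, g_t, b1, t):
      # Scalar form of the first moment computed in the `if nesterov:` branch:
      # b1 * m_t / (1 - b1**(t + 1)) + (1 - b1) * g_t / (1 - b1**t).
      return b1 * m_t / (1 - b1 ** (t + 1)) + (1 - b1) * g_t / (1 - b1 ** t)


    def loss(params):
      # Toy quadratic whose minimum is at params == 1.
      return jnp.sum(jnp.square(params - 1.0))


    params = jnp.zeros(3)
    opt = optax.adabelief(learning_rate=0.1, nesterov=True)
    state = opt.init(params)

    for _ in range(200):
      grads = jax.grad(loss)(params)
      updates, state = opt.update(grads, state, params)
      params = optax.apply_updates(params, updates)

    print(params)  # Should end up close to jnp.ones(3).

Running the same loop with the default `nesterov=False` exercises the unchanged `else` branch, which is a convenient check that the flag is a no-op for existing users.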