
Commit

Merge pull request #1121 from aman2304:lion_desc_v1
PiperOrigin-RevId: 692338045
OptaxDev committed Nov 2, 2024
2 parents 66235a1 + 9f6b18f commit d4592a6
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions optax/_src/alias.py
@@ -883,7 +883,7 @@ def lion(
weight_decay: float = 1e-3,
mask: Optional[Union[Any, Callable[[base.Params], Any]]] = None,
) -> base.GradientTransformation:
"""The Lion optimizer.
r"""The Lion optimizer.
Lion was discovered by symbolic program search. Unlike most adaptive optimizers
such as AdamW, Lion only tracks momentum, making it more memory-efficient.
@@ -892,7 +892,32 @@ def lion(
AdamW. A suitable learning rate for Lion is typically 3-10x smaller than that
for AdamW; the weight decay for Lion should in turn be 3-10x larger than that
for AdamW to maintain a similar strength (lr * wd).

Let :math:`\alpha_t` represent the learning rate and :math:`\beta_1, \beta_2`
represent the arguments ``b1`` and ``b2`` respectively. The learning rate is
indexed by :math:`t` since the learning rate may also be provided by a
schedule function. Let :math:`\lambda` be the weight decay and
:math:`\theta_t` the parameter vector at time :math:`t`.

The ``init`` function of this optimizer initializes an internal state
:math:`S_0 := (m_0) = (0)`, representing the initial estimate for the
first moment. In practice these values are stored as pytrees
containing all zeros, with the same shape as the model updates.

At step :math:`t`, the ``update`` function of this optimizer takes as
arguments the incoming gradients :math:`g_t`, the optimizer state :math:`S_t`
and the parameters :math:`\theta_t` and computes updates :math:`u_t` and
new state :math:`S_{t+1}`. Thus, for :math:`t > 0`, we have,

.. math::

  \begin{align*}
    c_t &\leftarrow \beta_1 \cdot m_{t-1} + (1-\beta_1) \cdot g_t \\
    u_t &\leftarrow -\alpha_t \cdot \left( \mathrm{sign}(c_t) +
    \lambda \theta_{t} \right) \\
    m_t &\leftarrow \beta_2 \cdot m_{t-1} + (1-\beta_2) \cdot g_t \\
    S_{t+1} &\leftarrow (m_t).
  \end{align*}

Examples:
>>> import optax
>>> import jax
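As a reading aid, here is a minimal sketch of the update rule described in the
new docstring, written in plain jax.numpy over a single array rather than a
pytree. The names lion_init, lion_update, lr, b1, b2, and wd are illustrative
stand-ins for alpha_t, beta_1, beta_2, and lambda; they are not part of this
commit.

import jax.numpy as jnp

def lion_init(params):
  # S_0 = (m_0) = (0): the first-moment estimate starts as zeros
  # with the same shape as the parameters.
  return jnp.zeros_like(params)

def lion_update(grads, m, params, lr=1e-4, b1=0.9, b2=0.99, wd=1e-3):
  # c_t = b1 * m_{t-1} + (1 - b1) * g_t  (interpolated update direction)
  c = b1 * m + (1 - b1) * grads
  # u_t = -alpha_t * (sign(c_t) + lambda * theta_t)
  updates = -lr * (jnp.sign(c) + wd * params)
  # m_t = b2 * m_{t-1} + (1 - b2) * g_t, carried forward as the new state
  new_m = b2 * m + (1 - b2) * grads
  return updates, new_m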
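The 3-10x guidance in the docstring amounts to keeping the product lr * wd
roughly constant when porting AdamW settings to Lion. A sketch, assuming a
hypothetical AdamW baseline of lr=3e-4 and weight_decay=1e-2:

import optax

adamw_lr, adamw_wd = 3e-4, 1e-2      # hypothetical AdamW baseline
solver = optax.lion(
    learning_rate=adamw_lr / 10,     # 3-10x smaller learning rate
    weight_decay=adamw_wd * 10,      # 3-10x larger weight decay
)
# 3e-4 * 1e-2 == 3e-5 * 1e-1 == 3e-6, so the strength lr * wd is preserved.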
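The Examples doctest in the diff is cut off above; for context, a typical
end-to-end use of optax.lion follows the standard init/update/apply_updates
pattern. The toy quadratic loss and the constants here are illustrative only:

import jax
import jax.numpy as jnp
import optax

def loss_fn(params):
  return jnp.sum(params ** 2)  # toy quadratic objective

solver = optax.lion(learning_rate=3e-3)
params = jnp.array([1.0, 2.0, 3.0])
opt_state = solver.init(params)

for _ in range(5):
  grads = jax.grad(loss_fn)(params)
  updates, opt_state = solver.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)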
