
Commit

Adding new optimizers to the doc.
PiperOrigin-RevId: 604576309
vroulet authored and OptaxDev committed Feb 6, 2024
1 parent 8ba9c02 commit 54c238c
Showing 5 changed files with 38 additions and 4 deletions.
9 changes: 9 additions & 0 deletions docs/api/optimizers.rst
@@ -5,6 +5,7 @@ Optimizers

.. autosummary::
adabelief
adadelta
adafactor
adagrad
adam
@@ -33,6 +34,10 @@ AdaBelief
~~~~~~~~~
.. autofunction:: adabelief

AdaDelta
~~~~~~~~~
.. autofunction:: adadelta

AdaGrad
~~~~~~~
.. autofunction:: adagrad
@@ -109,6 +114,10 @@ RMSProp
~~~~~~~
.. autofunction:: rmsprop

RProp
~~~~~
.. autofunction:: rprop

SGD
~~~
.. autofunction:: sgd
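
Both newly documented aliases follow the usual Optax init/update/apply_updates pattern. A minimal usage sketch (the toy loss and hyperparameter values below are illustrative, not taken from this commit):

    import jax
    import jax.numpy as jnp
    import optax

    params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}

    def loss_fn(p):
        # Toy quadratic loss, only used here to produce gradients.
        return jnp.sum((p["w"] * 2.0 + p["b"]) ** 2)

    # optax.rprop(learning_rate=1e-2) can be dropped in the same way.
    opt = optax.adadelta(learning_rate=1e-1)
    opt_state = opt.init(params)

    grads = jax.grad(loss_fn)(params)
    updates, opt_state = opt.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)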
14 changes: 14 additions & 0 deletions docs/api/transformations.rst
@@ -34,6 +34,8 @@ Transformations
per_example_layer_norm_clip
scale
ScaleState
scale_by_adadelta
ScaleByAdaDeltaState
scale_by_adam
scale_by_adamax
ScaleByAdamState
@@ -54,6 +56,8 @@ Transformations
scale_by_radam
scale_by_rms
ScaleByRmsState
scale_by_rprop
ScaleByRpropState
scale_by_rss
ScaleByRssState
scale_by_schedule
@@ -160,6 +164,10 @@ Transformations and states
.. autoclass:: ScaleState
:members:

.. autofunction:: scale_by_adadelta
.. autoclass:: ScaleByAdaDeltaState
:members:

.. autofunction:: scale_by_adam
.. autofunction:: scale_by_adamax
.. autoclass:: ScaleByAdamState
@@ -177,6 +185,8 @@
.. autoclass:: FactoredState
:members:

.. autofunction:: scale_by_learning_rate

.. autofunction:: scale_by_lion
.. autoclass:: ScaleByLionState
:members:
@@ -197,6 +207,10 @@ Transformations and states
.. autoclass:: ScaleByRmsState
:members:

.. autofunction:: scale_by_rprop
.. autoclass:: ScaleByRpropState
:members:

.. autofunction:: scale_by_rss
.. autoclass:: ScaleByRssState
:members:
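
As a rough sketch of how the new transformations are meant to be composed (the rho and eps values here are illustrative, and the chain is an assumption rather than the exact definition of the adadelta alias), scale_by_adadelta pairs naturally with the newly documented scale_by_learning_rate:

    import optax

    # Adadelta preconditioning followed by (negated) learning-rate scaling;
    # scale_by_rprop composes with other transformations in the same way.
    tx = optax.chain(
        optax.scale_by_adadelta(rho=0.9, eps=1e-6),
        optax.scale_by_learning_rate(1e-1),
    )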
10 changes: 10 additions & 0 deletions optax/__init__.py
@@ -101,6 +101,7 @@
from optax._src.transform import ema
from optax._src.transform import EmaState
from optax._src.transform import scale
from optax._src.transform import scale_by_adadelta
from optax._src.transform import scale_by_adam
from optax._src.transform import scale_by_adamax
from optax._src.transform import scale_by_amsgrad
@@ -114,18 +115,21 @@
from optax._src.transform import scale_by_param_block_rms
from optax._src.transform import scale_by_radam
from optax._src.transform import scale_by_rms
from optax._src.transform import scale_by_rprop
from optax._src.transform import scale_by_rss
from optax._src.transform import scale_by_schedule
from optax._src.transform import scale_by_sm3
from optax._src.transform import scale_by_stddev
from optax._src.transform import scale_by_trust_ratio
from optax._src.transform import scale_by_yogi
from optax._src.transform import ScaleByAdaDeltaState
from optax._src.transform import ScaleByAdamState
from optax._src.transform import ScaleByAmsgradState
from optax._src.transform import ScaleByBeliefState
from optax._src.transform import ScaleByLionState
from optax._src.transform import ScaleByNovogradState
from optax._src.transform import ScaleByRmsState
from optax._src.transform import ScaleByRpropState
from optax._src.transform import ScaleByRssState
from optax._src.transform import ScaleByRStdDevState
from optax._src.transform import ScaleByScheduleState
@@ -210,6 +214,7 @@

__all__ = (
"adabelief",
"adadelta",
"adafactor",
"adagrad",
"adam",
@@ -302,10 +307,12 @@
"power_iteration",
"radam",
"rmsprop",
"rprop",
"safe_int32_increment",
"safe_norm",
"safe_root_mean_squares",
"ScalarOrSchedule",
"scale_by_adadelta",
"scale_by_adam",
"scale_by_adamax",
"scale_by_amsgrad",
@@ -317,6 +324,7 @@
"scale_by_param_block_rms",
"scale_by_radam",
"scale_by_rms",
"scale_by_rprop",
"scale_by_rss",
"scale_by_schedule",
"scale_by_sm3",
@@ -325,12 +333,14 @@
"scale_by_yogi",
"scale_gradient",
"scale",
"ScaleByAdaDeltaState",
"ScaleByAdamState",
"ScaleByAmsgradState",
"ScaleByBeliefState",
"ScaleByLionState",
"ScaleByNovogradState",
"ScaleByRmsState",
"ScaleByRpropState",
"ScaleByRssState",
"ScaleByRStdDevState",
"ScaleByScheduleState",
3 changes: 2 additions & 1 deletion optax/_src/alias.py
@@ -125,7 +125,8 @@ def adadelta(
[Matthew D. Zeiler, 2012](https://arxiv.org/pdf/1212.5701.pdf)
Args:
-    learning_rate: A fixed global scaling factor.
+    learning_rate: A global scaling factor, either fixed or evolving along
+      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
rho: A coefficient used for computing a running average of squared
gradients.
eps: Term added to the denominator to improve numerical stability.
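
The updated docstring means a schedule can be passed wherever a fixed learning rate was accepted before. A brief sketch (the specific schedule and its values are only for illustration):

    import optax

    # Any Optax schedule can stand in for a constant learning rate.
    schedule = optax.exponential_decay(
        init_value=1.0, transition_steps=1_000, decay_rate=0.9)
    opt = optax.adadelta(learning_rate=schedule, rho=0.9, eps=1e-6)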
6 changes: 3 additions & 3 deletions optax/_src/transform.py
@@ -608,7 +608,7 @@ def update_fn(updates, state, params):
return base.GradientTransformation(init_fn, update_fn)


- class ScaleByAdadelta(NamedTuple):
+ class ScaleByAdaDeltaState(NamedTuple):
"""State for the rescaling by Adadelta algoritm."""

e_g: base.Updates
@@ -635,7 +635,7 @@ def scale_by_adadelta(
def init_fn(params):
e_g = jax.tree_util.tree_map(jnp.zeros_like, params) # E[squared gradient]
e_x = jax.tree_util.tree_map(jnp.zeros_like, params) # E[squared update]
- return ScaleByAdadelta(e_g=e_g, e_x=e_x)
+ return ScaleByAdaDeltaState(e_g=e_g, e_x=e_x)

def update_fn(updates, state, params=None):
del params
@@ -650,7 +650,7 @@ def update_fn(updates, state, params=None):
state.e_x,
)
e_x = update_moment(updates, state.e_x, rho, 2)
- return updates, ScaleByAdadelta(e_g=e_g, e_x=e_x)
+ return updates, ScaleByAdaDeltaState(e_g=e_g, e_x=e_x)

return base.GradientTransformation(init_fn, update_fn)
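
To make the role of the two state fields concrete, here is a single-tensor sketch of the scaling that scale_by_adadelta applies (sign handling is left to the learning-rate scaling applied afterwards; this mirrors the update_fn above but is not the library code itself):

    import jax.numpy as jnp

    def adadelta_scale(g, e_g, e_x, rho=0.9, eps=1e-6):
        # e_g: running average of squared gradients; e_x: running average
        # of squared scaled updates. Both start at zeros in init_fn.
        e_g = rho * e_g + (1 - rho) * g**2
        update = jnp.sqrt(e_x + eps) / jnp.sqrt(e_g + eps) * g
        e_x = rho * e_x + (1 - rho) * update**2
        return update, e_g, e_x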

