diff --git a/docs/api/optimizers.rst b/docs/api/optimizers.rst
index d88e1596..2808b914 100644
--- a/docs/api/optimizers.rst
+++ b/docs/api/optimizers.rst
@@ -5,6 +5,7 @@ Optimizers
 
 .. autosummary::
     adabelief
+    adadelta
    adafactor
     adagrad
     adam
@@ -33,6 +34,10 @@ AdaBelief
 ~~~~~~~~~
 .. autofunction:: adabelief
 
+AdaDelta
+~~~~~~~~~
+.. autofunction:: adadelta
+
 AdaGrad
 ~~~~~~~
 .. autofunction:: adagrad
@@ -109,6 +114,10 @@ RMSProp
 ~~~~~~~
 .. autofunction:: rmsprop
 
+RProp
+~~~~~
+.. autofunction:: rprop
+
 SGD
 ~~~
 .. autofunction:: sgd
diff --git a/docs/api/transformations.rst b/docs/api/transformations.rst
index 731ae794..5ca5289f 100644
--- a/docs/api/transformations.rst
+++ b/docs/api/transformations.rst
@@ -34,6 +34,8 @@ Transformations
     per_example_layer_norm_clip
     scale
     ScaleState
+    scale_by_adadelta
+    ScaleByAdaDeltaState
     scale_by_adam
     scale_by_adamax
     ScaleByAdamState
@@ -54,6 +56,8 @@ Transformations
     scale_by_radam
     scale_by_rms
     ScaleByRmsState
+    scale_by_rprop
+    ScaleByRpropState
     scale_by_rss
     ScaleByRssState
     scale_by_schedule
@@ -160,6 +164,10 @@ Transformations and states
 .. autoclass:: ScaleState
     :members:
 
+.. autofunction:: scale_by_adadelta
+.. autoclass:: ScaleByAdaDeltaState
+    :members:
+
 .. autofunction:: scale_by_adam
 .. autofunction:: scale_by_adamax
 .. autoclass:: ScaleByAdamState
@@ -177,6 +185,8 @@ Transformations and states
 .. autoclass:: FactoredState
     :members:
 
+.. autofunction:: scale_by_learning_rate
+
 .. autofunction:: scale_by_lion
 .. autoclass:: ScaleByLionState
     :members:
@@ -197,6 +207,10 @@ Transformations and states
 .. autoclass:: ScaleByRmsState
     :members:
 
+.. autofunction:: scale_by_rprop
+.. autoclass:: ScaleByRpropState
+    :members:
+
 .. autofunction:: scale_by_rss
 .. autoclass:: ScaleByRssState
     :members:
diff --git a/optax/__init__.py b/optax/__init__.py
index 84282de4..4de16566 100644
--- a/optax/__init__.py
+++ b/optax/__init__.py
@@ -101,6 +101,7 @@
 from optax._src.transform import ema
 from optax._src.transform import EmaState
 from optax._src.transform import scale
+from optax._src.transform import scale_by_adadelta
 from optax._src.transform import scale_by_adam
 from optax._src.transform import scale_by_adamax
 from optax._src.transform import scale_by_amsgrad
@@ -114,18 +115,21 @@
 from optax._src.transform import scale_by_param_block_rms
 from optax._src.transform import scale_by_radam
 from optax._src.transform import scale_by_rms
+from optax._src.transform import scale_by_rprop
 from optax._src.transform import scale_by_rss
 from optax._src.transform import scale_by_schedule
 from optax._src.transform import scale_by_sm3
 from optax._src.transform import scale_by_stddev
 from optax._src.transform import scale_by_trust_ratio
 from optax._src.transform import scale_by_yogi
+from optax._src.transform import ScaleByAdaDeltaState
 from optax._src.transform import ScaleByAdamState
 from optax._src.transform import ScaleByAmsgradState
 from optax._src.transform import ScaleByBeliefState
 from optax._src.transform import ScaleByLionState
 from optax._src.transform import ScaleByNovogradState
 from optax._src.transform import ScaleByRmsState
+from optax._src.transform import ScaleByRpropState
 from optax._src.transform import ScaleByRssState
 from optax._src.transform import ScaleByRStdDevState
 from optax._src.transform import ScaleByScheduleState
@@ -210,6 +214,7 @@
 
 __all__ = (
     "adabelief",
+    "adadelta",
     "adafactor",
     "adagrad",
     "adam",
@@ -302,10 +307,12 @@
     "power_iteration",
     "radam",
     "rmsprop",
+    "rprop",
     "safe_int32_increment",
     "safe_norm",
     "safe_root_mean_squares",
     "ScalarOrSchedule",
+    "scale_by_adadelta",
     "scale_by_adam",
     "scale_by_adamax",
     "scale_by_amsgrad",
@@ -317,6 +324,7 @@
     "scale_by_param_block_rms",
     "scale_by_radam",
     "scale_by_rms",
+    "scale_by_rprop",
     "scale_by_rss",
     "scale_by_schedule",
     "scale_by_sm3",
@@ -325,12 +333,14 @@
     "scale_by_yogi",
     "scale_gradient",
     "scale",
+    "ScaleByAdaDeltaState",
     "ScaleByAdamState",
     "ScaleByAmsgradState",
     "ScaleByBeliefState",
     "ScaleByLionState",
     "ScaleByNovogradState",
     "ScaleByRmsState",
+    "ScaleByRpropState",
     "ScaleByRssState",
     "ScaleByRStdDevState",
     "ScaleByScheduleState",
diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index 12988f0c..bb0d9636 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -125,7 +125,8 @@ def adadelta(
   [Matthew D. Zeiler, 2012](https://arxiv.org/pdf/1212.5701.pdf)
 
   Args:
-    learning_rate: A fixed global scaling factor.
+    learning_rate: A global scaling factor, either fixed or evolving along
+      iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
     rho: A coefficient used for computing a running average of squared
       gradients.
     eps: Term added to the denominator to improve numerical stability.
diff --git a/optax/_src/transform.py b/optax/_src/transform.py
index e8abb785..7fc81179 100644
--- a/optax/_src/transform.py
+++ b/optax/_src/transform.py
@@ -608,7 +608,7 @@ def update_fn(updates, state, params):
   return base.GradientTransformation(init_fn, update_fn)
 
 
-class ScaleByAdadelta(NamedTuple):
+class ScaleByAdaDeltaState(NamedTuple):
   """State for the rescaling by Adadelta algoritm."""
 
   e_g: base.Updates
@@ -635,7 +635,7 @@ def scale_by_adadelta(
   def init_fn(params):
     e_g = jax.tree_util.tree_map(jnp.zeros_like, params)  # E[squared gradient]
     e_x = jax.tree_util.tree_map(jnp.zeros_like, params)  # E[squared update]
-    return ScaleByAdadelta(e_g=e_g, e_x=e_x)
+    return ScaleByAdaDeltaState(e_g=e_g, e_x=e_x)
 
   def update_fn(updates, state, params=None):
     del params
@@ -650,7 +650,7 @@ def update_fn(updates, state, params=None):
         state.e_x,
     )
     e_x = update_moment(updates, state.e_x, rho, 2)
-    return updates, ScaleByAdadelta(e_g=e_g, e_x=e_x)
+    return updates, ScaleByAdaDeltaState(e_g=e_g, e_x=e_x)
 
   return base.GradientTransformation(init_fn, update_fn)
 