diff --git a/optax/_src/alias.py b/optax/_src/alias.py
index 08346fa1..8be68b7d 100644
--- a/optax/_src/alias.py
+++ b/optax/_src/alias.py
@@ -26,7 +26,7 @@
 from optax._src import linesearch as _linesearch
 from optax._src import transform
 from optax._src import wrappers
-
+import chex

 MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]

@@ -1253,10 +1253,10 @@ def lamb(


 def noisy_sgd(
+    key: chex.PRNGKey,
     learning_rate: base.ScalarOrSchedule,
     eta: float = 0.01,
     gamma: float = 0.55,
-    seed: int = 0,
 ) -> base.GradientTransformation:
   r"""A variant of SGD with added noise.

@@ -1282,12 +1282,12 @@ def noisy_sgd(
   represents the initial variance ``eta``.

   Args:
+    key: a PRNG key used as the random key.
     learning_rate: A global scaling factor, either fixed or evolving along
       iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
     eta: Initial variance for the Gaussian noise added to gradients.
     gamma: A parameter controlling the annealing of noise over time ``t``,
       the variance decays according to ``(1+t)**(-gamma)``.
-    seed: Seed for the pseudo-random generation process.

   Returns:
     The corresponding :class:`optax.GradientTransformation`.
@@ -1297,7 +1297,8 @@ def noisy_sgd(
     >>> import jax
     >>> import jax.numpy as jnp
     >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
-    >>> solver = optax.noisy_sgd(learning_rate=0.003)
+    >>> key = jax.random.key(42)
+    >>> solver = optax.noisy_sgd(key, learning_rate=0.003)
     >>> params = jnp.array([1., 2., 3.])
     >>> print('Objective function: ', f(params))
     Objective function:  14.0
@@ -1318,7 +1319,7 @@ def noisy_sgd(
     Networks <https://arxiv.org/abs/1511.06807>`_, 2015
   """
   return combine.chain(
-      transform.add_noise(eta, gamma, seed),
+      transform.add_noise(eta, gamma, key),
       transform.scale_by_learning_rate(learning_rate),
   )

diff --git a/optax/transforms/_adding.py b/optax/transforms/_adding.py
index cfc5c93f..4578ce53 100644
--- a/optax/transforms/_adding.py
+++ b/optax/transforms/_adding.py
@@ -71,18 +71,42 @@ class AddNoiseState(NamedTuple):


 def add_noise(
-    eta: float, gamma: float, seed: int
+    eta: float, gamma: float, key: chex.PRNGKey
 ) -> base.GradientTransformation:
   """Add gradient noise.

   Args:
     eta: Base variance of the gaussian noise added to the gradient.
     gamma: Decay exponent for annealing of the variance.
-    seed: Seed for random number generation.
+    key: a PRNG key used as the random key.

   Returns:
     A :class:`optax.GradientTransformation` object.

+  Examples:
+    >>> import optax
+    >>> import jax
+    >>> import jax.numpy as jnp
+    >>> def f(x): return jnp.sum(x ** 2)  # simple quadratic function
+    >>> key = jax.random.key(42)
+    >>> noise = optax.add_noise(eta=0.01, gamma=0.55, key=key)
+    >>> sgd = optax.scale_by_learning_rate(learning_rate=0.003)
+    >>> solver = optax.chain(noise, sgd)
+    >>> params = jnp.array([1., 2., 3.])
+    >>> print('Objective function: ', f(params))
+    Objective function:  14.0
+    >>> opt_state = solver.init(params)
+    >>> for _ in range(5):
+    ...  grad = jax.grad(f)(params)
+    ...  updates, opt_state = solver.update(grad, opt_state, params)
+    ...  params = optax.apply_updates(params, updates)
+    ...  print('Objective function: {:.2E}'.format(f(params)))
+    Objective function: 1.38E+01
+    Objective function: 1.37E+01
+    Objective function: 1.35E+01
+    Objective function: 1.33E+01
+    Objective function: 1.32E+01
+
   References:
     Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep
     Networks <https://arxiv.org/abs/1511.06807>`_, 2015
@@ -91,7 +115,7 @@ def add_noise(
   def init_fn(params):
     del params
     return AddNoiseState(
-        count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed)
+        count=jnp.zeros([], jnp.int32), rng_key=key
     )

   def update_fn(updates, state, params=None):  # pylint: disable=missing-docstring
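
Note (editor's sketch, not part of the patch): the snippet below illustrates how downstream code might adopt the key-based signatures introduced above, assuming this diff is applied. The loss function, array shapes, and key-splitting scheme are illustrative assumptions, not taken from the patch.

    # Illustrative only: migrating from the old seed-based call
    # (optax.noisy_sgd(learning_rate=..., seed=0)) to the key-based one.
    import jax
    import jax.numpy as jnp
    import optax

    def loss(params):
      return jnp.sum(params ** 2)  # toy quadratic objective

    # Split one experiment key so that parameter initialisation and the
    # optimizer's noise draw from independent PRNG streams, which is the
    # main benefit of passing an explicit key instead of an int seed.
    root_key = jax.random.key(0)
    init_key, noise_key = jax.random.split(root_key)

    solver = optax.noisy_sgd(noise_key, learning_rate=0.003, eta=0.01, gamma=0.55)
    params = jax.random.normal(init_key, (3,))
    opt_state = solver.init(params)

    for _ in range(10):
      grads = jax.grad(loss)(params)
      updates, opt_state = solver.update(grads, opt_state, params)
      params = optax.apply_updates(params, updates)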