Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replaced seed with key in add_noise and noisy_sgd #1138

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions optax/_src/alias.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from optax._src import linesearch as _linesearch
from optax._src import transform
from optax._src import wrappers

import chex

MaskOrFn = Optional[Union[Any, Callable[[base.Params], Any]]]

Expand Down Expand Up @@ -1253,10 +1253,10 @@ def lamb(


def noisy_sgd(
key: chex.PRNGKey,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would still keep learning_rate first as in other optimizers for this function, and I would give a default value for the key for convenience, like key=jax.random.PNRGKey(0).
Sometimes one wants to benchmark quickly optimizers and it seems easier if they all have a similar signature, at least for the first argument.
On the other hand, for the transform, (add_noise) and other functions Robert pointed out, I agree that it would make sense to have the key first.

learning_rate: base.ScalarOrSchedule,
eta: float = 0.01,
gamma: float = 0.55,
seed: int = 0,
) -> base.GradientTransformation:
r"""A variant of SGD with added noise.

Expand All @@ -1282,12 +1282,12 @@ def noisy_sgd(
represents the initial variance ``eta``.

Args:
key: a PRNG key used as the random key.
learning_rate: A global scaling factor, either fixed or evolving along
iterations with a scheduler, see :func:`optax.scale_by_learning_rate`.
eta: Initial variance for the Gaussian noise added to gradients.
gamma: A parameter controlling the annealing of noise over time ``t``, the
variance decays according to ``(1+t)**(-gamma)``.
seed: Seed for the pseudo-random generation process.

Returns:
The corresponding :class:`optax.GradientTransformation`.
Expand All @@ -1297,7 +1297,8 @@ def noisy_sgd(
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> solver = optax.noisy_sgd(learning_rate=0.003)
>>> key = jax.random.key(42)
>>> solver = optax.noisy_sgd(key, learning_rate=0.003)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
Expand All @@ -1318,7 +1319,7 @@ def noisy_sgd(
Networks <https://arxiv.org/abs/1511.06807>`_, 2015
"""
return combine.chain(
transform.add_noise(eta, gamma, seed),
transform.add_noise(eta, gamma, key),
transform.scale_by_learning_rate(learning_rate),
)

Expand Down
30 changes: 27 additions & 3 deletions optax/transforms/_adding.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,42 @@ class AddNoiseState(NamedTuple):


def add_noise(
eta: float, gamma: float, seed: int
eta: float, gamma: float, key: chex.PRNGKey
) -> base.GradientTransformation:
"""Add gradient noise.

Args:
eta: Base variance of the gaussian noise added to the gradient.
gamma: Decay exponent for annealing of the variance.
seed: Seed for random number generation.
key: a PRNG key used as the random key.

Returns:
A :class:`optax.GradientTransformation` object.

Examples:
>>> import optax
>>> import jax
>>> import jax.numpy as jnp
>>> def f(x): return jnp.sum(x ** 2) # simple quadratic function
>>> key = jax.random.key(42)
>>> noise = optax.add_noise(eta=0.01, gamma=0.55, key=key)
>>> sgd = optax.scale_by_learning_rate(learning_rate=0.003)
>>> solver = optax.chain(noise, sgd)
>>> params = jnp.array([1., 2., 3.])
>>> print('Objective function: ', f(params))
Objective function: 14.0
>>> opt_state = solver.init(params)
>>> for _ in range(5):
... grad = jax.grad(f)(params)
... updates, opt_state = solver.update(grad, opt_state, params)
... params = optax.apply_updates(params, updates)
... print('Objective function: {:.2E}'.format(f(params)))
Objective function: 1.38E+01
Objective function: 1.37E+01
Objective function: 1.35E+01
Objective function: 1.33E+01
Objective function: 1.32E+01

References:
Neelakantan et al, `Adding Gradient Noise Improves Learning for Very Deep
Networks <https://arxiv.org/abs/1511.06807>`_, 2015
Expand All @@ -91,7 +115,7 @@ def add_noise(
def init_fn(params):
del params
return AddNoiseState(
count=jnp.zeros([], jnp.int32), rng_key=jax.random.PRNGKey(seed)
count=jnp.zeros([], jnp.int32), rng_key=key
)

def update_fn(updates, state, params=None): # pylint: disable=missing-docstring
Expand Down
Loading