Ensure optimizers return updates of same dtype as params.
Fix #1038, fix #377

PiperOrigin-RevId: 674026550
vroulet authored and OptaxDev committed Sep 13, 2024
1 parent c0e4228 commit 6ef174d
Showing 22 changed files with 612 additions and 443 deletions.
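
As a minimal sketch of the invariant this change enforces (using optax.adam as a representative optimizer; the tests below cover the full list of aliases):

import jax
import jax.numpy as jnp
import optax

def loss(x):
  return jnp.sum(x ** 2)

# Low-precision parameters, as used e.g. in mixed-precision training.
params = jnp.array([1.0, 2.0], dtype=jnp.bfloat16)
grads = jax.grad(loss)(params)

opt = optax.adam(learning_rate=1e-3)
state = opt.init(params)
updates, state = opt.update(grads, state, params)

# With this change, the updates come back in the dtype of the params.
assert updates.dtype == params.dtype
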
10 changes: 10 additions & 0 deletions docs/api/utilities.rst
@@ -92,7 +92,9 @@ Tree
NamedTupleKey
tree_add
tree_add_scalar_mul
tree_cast
tree_div
tree_dtype
tree_get
tree_get_all_with_path
tree_l1_norm
@@ -121,6 +123,14 @@ Tree add and scalar multiply
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_add_scalar_mul

Tree cast
~~~~~~~~~
.. autofunction:: tree_cast

Tree dtype
~~~~~~~~~~
.. autofunction:: tree_dtype

Tree divide
~~~~~~~~~~~
.. autofunction:: tree_div
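
A short usage sketch for the two newly documented helpers (assuming the signatures tree_cast(tree, dtype) and tree_dtype(tree)):

import jax.numpy as jnp
import optax.tree_utils as otu

tree = {'w': jnp.ones((2, 2), dtype=jnp.float32),
        'b': jnp.zeros((2,), dtype=jnp.float32)}

# Cast every leaf of the tree to bfloat16.
low = otu.tree_cast(tree, jnp.bfloat16)

# Recover the common dtype of the leaves.
print(otu.tree_dtype(low))  # bfloat16
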
6 changes: 5 additions & 1 deletion docs/development.md
@@ -53,7 +53,11 @@ years, well-cited (100+ citations), and demonstrate broad utility.
if they offer clear advantages over widely used methods.

If your algorithm doesn't meet the main package criteria, the {doc}`api/contrib`
directory is perfect for sharing innovative work.
directory is perfect for sharing innovative work. Please make sure that all
common tests (in `optax/contrib/_common_test.py` or `optax/_src/alias_test.py`)
pass when you propose a new algorithm. These tests ensure that new algorithms
interoperate with the other features of optax (such as gradient accumulation
or varying hyperparameters).
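
Concretely, the common tests run a candidate algorithm through wrappers like the following (a hedged sketch, with optax.adam standing in for the new algorithm):

import jax
import jax.numpy as jnp
import optax

params = jnp.array([1.0, 2.0])
grads = jax.grad(lambda x: jnp.sum(x ** 2))(params)

# Gradient accumulation: the wrapped optimizer is only applied every 4 steps.
opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
state = opt.init(params)
updates, state = opt.update(grads, state, params)

# Varying hyperparameters: the learning rate becomes part of the state and
# can be changed on the fly.
opt = optax.inject_hyperparams(optax.adam)(learning_rate=1e-3)
state = opt.init(params)
updates, state = opt.update(grads, state, params)
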


## Design Documents
72 changes: 67 additions & 5 deletions optax/_src/alias_test.py
@@ -35,8 +35,10 @@
from optax._src import update
from optax.losses import _classification
from optax.schedules import _inject
from optax.transforms import _accumulation
import optax.tree_utils as otu


import scipy.optimize as scipy_optimize
from sklearn import datasets
from sklearn import linear_model
@@ -164,13 +166,16 @@ def step(params, state):

params = initial_params
state = opt.init(params)
# A no-op change, to verify that tree map works.
state = otu.tree_map_params(opt, lambda v: v, state)

for _ in range(10000):
params, state = step(params, state)
with self.subTest('Test that tree_map_params works'):
# A no-op change, to verify that tree map works.
state = otu.tree_map_params(opt, lambda v: v, state)

with self.subTest('Test that optimization works'):
for _ in range(10000):
params, state = step(params, state)

chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)
chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)

@chex.all_variants
@parameterized.product(_OPTIMIZERS_UNDER_TEST)
@@ -229,6 +234,63 @@ def test_explicit_dtype(self, dtype):
adam_state, _, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype)

# Not testing with `without_device=True` because without_device moves the
# variables to the host, which appears to convert the dtype, so we
# lose control of the dtype and the test fails.
@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
)
@parameterized.product(
_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32')
)
def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers return updates of same dtype as params."""
# When debugging this test, note that operations like
# x = 0.5**jnp.asarray(1, dtype=jnp.int32)
# (appearing in e.g. optax.tree_utils.tree_bias_correction)
# are promoted (strictly) to float32 when jitted,
# see https://github.com/google/jax/issues/23337.
# This may end up giving the updates a dtype different from the params.
# The solution is to fix the dtype of the result to the desired dtype
# (just as done in optax.tree_utils.tree_bias_correction).
dtype = jnp.dtype(dtype)
opt_factory = getattr(alias, opt_name)
opt = opt_factory(**opt_kwargs)
fun = lambda x: jnp.sum(x**2)

params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': fun(params)}
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
self.assertEqual(updates.dtype, params.dtype)
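
The promotion pitfall described in the comment above can be reproduced in isolation (a hedged sketch; the precise behavior depends on the JAX version, see the linked issue):

import jax
import jax.numpy as jnp

def bias_correction(decay, count):
  # The Python scalar is weakly typed: in eager mode the result adapts to the
  # other operand's dtype, but under jit it may be promoted to a hard float32.
  return 1.0 - decay ** count

m = jnp.ones((2,), dtype=jnp.bfloat16)
count = jnp.asarray(1, dtype=jnp.int32)

eager = m / bias_correction(0.999, count)
jitted = jax.jit(lambda m, c: m / bias_correction(0.999, c))(m, count)
print(eager.dtype, jitted.dtype)  # may differ, e.g. bfloat16 vs. float32

# The fix used throughout this change: cast the correction back explicitly,
# as done in optax.tree_utils.tree_bias_correction.
fixed = m / bias_correction(0.999, count).astype(m.dtype)
assert fixed.dtype == m.dtype
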

@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
)
@parameterized.product(_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32'))
def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers can safely be used with optax.MultiSteps."""
# Checks for issues like https://github.com/google-deepmind/optax/issues/377
dtype = jnp.dtype(dtype)
opt_factory = getattr(alias, opt_name)
base_opt = opt_factory(**opt_kwargs)
opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)

fun = lambda x: jnp.sum(x**2)

params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': fun(params)}
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
chex.assert_trees_all_equal(updates, jnp.zeros_like(grads))
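
For context, optax.MultiSteps emits all-zero updates while it is still accumulating gradients and only applies the wrapped optimizer on every k-th step; issue #377 (referenced above) concerned a dtype mismatch on this accumulation path. A minimal sketch of the expected behavior:

import jax
import jax.numpy as jnp
import optax

fun = lambda x: jnp.sum(x ** 2)
params = jnp.array([1.0, 2.0], dtype=jnp.bfloat16)

opt = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
state = opt.init(params)

for step in range(4):
  grads = jax.grad(fun)(params)
  updates, state = opt.update(grads, state, params)
  # Steps 0-2 return all-zero updates while gradients accumulate;
  # step 3 returns the accumulated update of the wrapped adam.
  print(step, bool(jnp.all(updates == 0)))
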

##########################
# ALGORITHM SPECIFIC TESTS
29 changes: 16 additions & 13 deletions optax/_src/factorized.py
@@ -126,23 +126,23 @@ def init_fn(params):
"""Initialise the optimiser's state."""

def _init(param):
shape = param.shape
shape, dtype = param.shape, param.dtype
factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor)
if factored_dims is not None:
d1, d0 = factored_dims
vr_shape = np.delete(shape, d0)
vc_shape = np.delete(shape, d1)
return _UpdateResult(
update=jnp.zeros((1,)),
v_row=jnp.zeros(vr_shape),
v_col=jnp.zeros(vc_shape),
v=jnp.zeros((1,)))
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros(vr_shape, dtype=dtype),
v_col=jnp.zeros(vc_shape, dtype=dtype),
v=jnp.zeros((1,), dtype=dtype))
else:
return _UpdateResult(
update=jnp.zeros((1,)),
v_row=jnp.zeros((1,)),
v_col=jnp.zeros((1,)),
v=jnp.zeros(param.shape))
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros((1,), dtype=dtype),
v_col=jnp.zeros((1,), dtype=dtype),
v=jnp.zeros(param.shape, dtype=dtype))

return _to_state(
jnp.zeros([], jnp.int32), jax.tree_util.tree_map(_init, params))
@@ -153,13 +153,13 @@ def update_fn(grads, state, params):
raise ValueError(base.NO_PARAMS_MSG)

def _update(grad, v_row, v_col, v, param, step):
shape = param.shape
shape, dtype = param.shape, param.dtype
decay_rate_t = decay_rate_fn(step - step_offset, decay_rate)

# Scaled by factorized second moment statistics.
new_v_row = jnp.zeros((1,))
new_v_col = jnp.zeros((1,))
new_v = jnp.zeros((1,))
new_v_row = jnp.zeros((1,), dtype=dtype)
new_v_col = jnp.zeros((1,), dtype=dtype)
new_v = jnp.zeros((1,), dtype=dtype)

factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor)
if factored_dims is not None:
@@ -171,6 +171,8 @@ def _update(grad, v_row, v_col, v, param, step):
new_v_col = (
decay_rate_t * v_col +
(1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1))
new_v_row = new_v_row.astype(dtype)
new_v_col = new_v_col.astype(dtype)
reduced_d1 = d1-1 if d1 > d0 else d1
row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
row_factor = (new_v_row / row_col_mean) ** -0.5
@@ -182,6 +184,7 @@
else:
grad_sqr = numerics.abs_sq(grad) + epsilon
new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr
new_v = new_v.astype(dtype)
update = grad * (new_v)**-0.5

return _UpdateResult(update, new_v_row, new_v_col, new_v)
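
For context, scale_by_factored_rms factors the second-moment statistics only for sufficiently large parameters: an (m, n) matrix stores a row statistic of shape (m,) and a column statistic of shape (n,) rather than a full (m, n) accumulator. A rough sketch of the saving (hypothetical shapes):

m, n = 4096, 1024
full = m * n      # entries in a full (m, n) second-moment accumulator
factored = m + n  # entries in the factored row/column statistics
print(full, factored, full / factored)  # 4194304 5120 819.2
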
49 changes: 49 additions & 0 deletions optax/_src/factorized_test.py
@@ -18,9 +18,11 @@
from absl.testing import parameterized

import chex
import jax
import jax.numpy as jnp

from optax._src import factorized
from optax.transforms import _accumulation


class FactorizedTest(parameterized.TestCase):
@@ -45,6 +47,53 @@ def test_scale_by_factored_rms(self):
chex.assert_tree_all_finite((params, updates, state))
chex.assert_trees_all_equal_shapes(params, updates)

@chex.variants(with_jit=True, without_jit=True, with_device=True)
@parameterized.product(
factorized_dims=(True, False),
dtype=('bfloat16', 'float32')
)
def test_preserve_dtype(self, factorized_dims: bool, dtype: str):
"""Test that the optimizer returns updates of same dtype as params."""
dtype = jnp.dtype(dtype)
opt = factorized.scale_by_factored_rms()
fun = lambda x: jnp.sum(x**2)

if factorized_dims:
# The updates are factored only for large enough parameters;
# the default min_dim_size_to_factor is 128, so we use 129 here.
params = jnp.ones((129, 129), dtype=dtype)
else:
params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
updates, _ = self.variant(opt.update)(grads, state, params)
self.assertEqual(updates.dtype, params.dtype)

@chex.variants(with_jit=True, without_jit=True, with_device=True)
@parameterized.product(
factorized_dims=(True, False),
dtype=('bfloat16', 'float32')
)
def test_gradient_accumulation(self, factorized_dims, dtype):
"""Test that the optimizers can safely be used with optax.MultiSteps."""
# Checks if https://github.com/google-deepmind/optax/issues/377 is fixed.
dtype = jnp.dtype(dtype)
base_opt = factorized.scale_by_factored_rms()
opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)

fun = lambda x: jnp.sum(x**2)

if factorized_dims:
# The updates are factored only for large enough parameters;
# the default min_dim_size_to_factor is 128, so we use 129 here.
params = jnp.ones((129, 129), dtype=dtype)
else:
params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
updates, _ = self.variant(opt.update)(grads, state, params)
chex.assert_trees_all_equal(updates, jnp.zeros_like(grads))


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions optax/_src/numerics.py
@@ -111,6 +111,7 @@ def safe_increment(count: chex.Numeric) -> chex.Numeric:
counter stays at ``max_val``.
Examples:
>>> import jax.numpy as jnp
>>> import optax
>>> optax.safe_increment(jnp.asarray(1, dtype=jnp.int32))
Array(2, dtype=int32)
26 changes: 13 additions & 13 deletions optax/_src/transform.py
@@ -717,14 +717,15 @@ def scale_by_radam(
A `GradientTransformation` object.
"""

ro_inf = 2./(1 - b2) - 1
def _radam_update(params):
ro = params[0]
mu_hat = params[1]
nu_hat = params[2]
r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro))
ro_inf = 2./(1. - b2) - 1.

def _radam_update(ro, mu_hat, nu_hat):
r = jnp.sqrt((ro - 4.)*(ro - 2.)*ro_inf/((ro_inf - 4.)*(ro_inf - 2.)*ro))
updates = jtu.tree_map(
lambda m, v: r*m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
lambda m, v: r.astype(m.dtype) * m / (jnp.sqrt(v + eps_root) + eps),
mu_hat,
nu_hat,
)
return updates

def init_fn(params):
@@ -750,7 +751,7 @@ def update_fn(updates, state, params=None):
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jax.tree_util.tree_map(
lambda t, f: jnp.where(ro >= threshold, t, f),
_radam_update((ro, mu_hat, nu_hat)),
_radam_update(ro, mu_hat, nu_hat),
mu_hat,
)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
@@ -1051,7 +1052,7 @@ def scale_by_sm3(
"""

def zeros_for_dim(p):
return [jnp.zeros([s]) for s in p.shape]
return [jnp.zeros([s], dtype=p.dtype) for s in p.shape]

def init_fn(params):
_reject_complex(params)
@@ -1137,8 +1138,8 @@ def scale_by_novograd(
mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
nu = jtu.tree_map(lambda _: 0.0, params) # Second moment
mu = otu.tree_zeros_like(params, dtype=mu_dtype)
nu = jtu.tree_map(lambda p: jnp.asarray(0.0, dtype=p.dtype), params)
return ScaleByNovogradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def nu_addition(grads):
@@ -1148,8 +1149,7 @@ def mu_addition(grads, params, nu):
return grads / (jnp.sqrt(nu + eps_root) + eps) + weight_decay * params

def init_nu(grads, nu):
del nu
return jtu.tree_map(nu_addition, grads)
return jtu.tree_map(lambda g, n: nu_addition(g).astype(n.dtype), grads, nu)

def update_nu(grads, nu):
updates = jtu.tree_map(nu_addition, grads)
6 changes: 4 additions & 2 deletions optax/contrib/_cocob.py
@@ -25,6 +25,7 @@
from optax._src import base
from optax._src import combine
from optax._src import transform
import optax.tree_utils as otu


class COCOBState(NamedTuple):
@@ -58,8 +59,9 @@ def scale_by_cocob(
"""

def init_fn(params):
init_adapt = jtu.tree_map(lambda p: jnp.zeros(p.shape), params)
init_scale = jtu.tree_map(lambda p: eps * jnp.ones(p.shape), params)
init_adapt = otu.tree_zeros_like(params)
init_scale = otu.tree_ones_like(params)
init_scale = otu.tree_scalar_mul(eps, init_scale)
return COCOBState(
init_particles=params,
cumulative_gradients=init_adapt,
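
The tree utilities used here keep the structure and dtype of the parameter tree, which is presumably what lets the initial COCOB state follow the params dtype (a small sketch with a toy parameter dict):

import jax.numpy as jnp
import optax.tree_utils as otu

params = {'w': jnp.ones((2, 2), dtype=jnp.bfloat16)}

zeros = otu.tree_zeros_like(params)
scaled = otu.tree_scalar_mul(1e-8, otu.tree_ones_like(params))
print(zeros['w'].dtype, scaled['w'].dtype)  # bfloat16 bfloat16
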