Skip to content

Commit

Permalink
Ensure optimizers return updates of same dtype as params.
Browse files Browse the repository at this point in the history
Fix #1038, fix #377, fix #1051

PiperOrigin-RevId: 674026550
  • Loading branch information
vroulet authored and OptaxDev committed Sep 23, 2024
1 parent 8543fe4 commit c15e795
Show file tree
Hide file tree
Showing 19 changed files with 690 additions and 476 deletions.
6 changes: 5 additions & 1 deletion docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,11 @@ years, well-cited (100+ citations), and demonstrate broad utility.
if they offer clear advantages over widely used methods.

If your algorithm doesn't meet the main package criteria, the {doc}`api/contrib`
directory is perfect for sharing innovative work.
directory is perfect for sharing innovative work. Please make sure that all
common tests (in `optax/contrib/_common_test.py` or `optax/_src/alias_test.py`)
are passed when you propose a new algorithm. These tests ensure the
interoperability of algorithms with different features of optax (such as
gradient accumulation or varying hyperparameters).


## Design Documents
Expand Down
133 changes: 111 additions & 22 deletions optax/_src/alias_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
from optax._src import update
from optax.losses import _classification
from optax.schedules import _inject
from optax.transforms import _accumulation
import optax.tree_utils as otu


import scipy.optimize as scipy_optimize
from sklearn import datasets
from sklearn import linear_model
Expand Down Expand Up @@ -163,13 +165,16 @@ def step(params, state):

params = initial_params
state = opt.init(params)
# A no-op change, to verify that tree map works.
state = otu.tree_map_params(opt, lambda v: v, state)

for _ in range(10000):
params, state = step(params, state)
with self.subTest('Test that tree_map_params works'):
# A no-op change, to verify that tree map works.
state = otu.tree_map_params(opt, lambda v: v, state)

with self.subTest('Test that optimization works'):
for _ in range(10000):
params, state = step(params, state)

chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)
chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)

@chex.all_variants
@parameterized.product(_OPTIMIZERS_UNDER_TEST)
Expand Down Expand Up @@ -210,24 +215,108 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
chex.assert_trees_all_close(
new_state_inject.inner_state, new_state, rtol=1e-4)

@parameterized.named_parameters([
('float32', 'float32'),
('bfloat16', 'bfloat16'),
('complex64', 'complex64'),
('None', None),
])
def test_explicit_dtype(self, dtype):
expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32
tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype)
trace_state, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, getattr(trace_state, 'trace').dtype)
tx = alias.adam(0.1, mu_dtype=dtype)
adam_state, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype)
tx = alias.adamw(0.1, mu_dtype=dtype)
adam_state, _, _ = tx.init(jnp.array([0.0, 0.0]))
self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype)
@parameterized.product(
params_dtype=('bfloat16', 'float32', 'complex64', None),
state_dtype=('bfloat16', 'float32', 'complex64', None),
opt_name=('sgd_mom', 'adam', 'adamw'),
)
def test_explicit_dtype(self, params_dtype, state_dtype, opt_name):
if opt_name == 'sgd_mom':
opt = alias.sgd(0.1, momentum=0.9, accumulator_dtype=state_dtype)
attribute_name = 'trace'
elif opt_name in ['adam', 'adamw']:
opt = getattr(alias, opt_name)(0.1, mu_dtype=state_dtype)
attribute_name = 'mu'
else:
raise ValueError(f'Unsupported optimizer: {opt_name}')

params_dtype = jax.dtypes.canonicalize_dtype(params_dtype)
params = jnp.array([0.0, 0.0], dtype=params_dtype)
state_has_lower_dtype = (
jnp.promote_types(params_dtype, jnp.dtype(state_dtype))
== params_dtype
)
if state_dtype is None or state_has_lower_dtype:
state = opt.init(params)

with self.subTest('Test that attribute dtype is correct'):
if state_dtype is None:
expected_dtype = params_dtype
else:
expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype)
attribute = otu.tree_get(state, attribute_name)
self.assertEqual(expected_dtype, attribute.dtype)

with self.subTest(
'Verifies that the updates keep the same type as params'
):
updates, _ = opt.update(jnp.ones_like(params), state, params)
self.assertEqual(updates.dtype, params.dtype)
else:
with self.subTest(
'Test that we forbid setting dtype s.t. updates dtype get promoted to'
' the state dtype'
):
with self.assertRaises(ValueError):
opt.init(params)

# Not testing with `without_device=True` because without_device set the
# variables to the host which appears to convert then the dtype, so we
# lose control of the dtype and the test fails.
@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
)
@parameterized.product(
_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32')
)
def test_preserve_dtype(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers return updates of same dtype as params."""
# When debugging this test, note that operations like
# x = 0.5**jnp.asarray(1, dtype=jnp.int32)
# (appearing in e.g. optax.tree_utils.tree_bias_correction)
# are promoted (strictly) to float32 when jitted
# see https://github.com/google/jax/issues/23337
# This may end up letting updates have a dtype different from params.
# The solution is to fix the dtype of the result to the desired dtype
# (just as done in optax.tree_utils.tree_bias_correction).
dtype = jnp.dtype(dtype)
opt_factory = getattr(alias, opt_name)
opt = opt_factory(**opt_kwargs)
fun = lambda x: jnp.sum(x**2)

params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': fun(params)}
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
self.assertEqual(updates.dtype, params.dtype)

@chex.variants(
with_jit=True, without_jit=True, with_device=True, with_pmap=True
)
@parameterized.product(_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32'))
def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype):
"""Test that the optimizers can safely be used with optax.MultiSteps."""
# Checks for issues like https://github.com/google-deepmind/optax/issues/377
dtype = jnp.dtype(dtype)
opt_factory = getattr(alias, opt_name)
base_opt = opt_factory(**opt_kwargs)
opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)

fun = lambda x: jnp.sum(x**2)

params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
if opt_name == 'polyak_sgd':
update_kwargs = {'value': fun(params)}
else:
update_kwargs = {}
updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs)
chex.assert_trees_all_equal(updates, jnp.zeros_like(grads))

##########################
# ALGORITHM SPECIFIC TESTS
Expand Down
29 changes: 16 additions & 13 deletions optax/_src/factorized.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,23 +126,23 @@ def init_fn(params):
"""Initialise the optimiser's state."""

def _init(param):
shape = param.shape
shape, dtype = param.shape, param.dtype
factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor)
if factored_dims is not None:
d1, d0 = factored_dims
vr_shape = np.delete(shape, d0)
vc_shape = np.delete(shape, d1)
return _UpdateResult(
update=jnp.zeros((1,)),
v_row=jnp.zeros(vr_shape),
v_col=jnp.zeros(vc_shape),
v=jnp.zeros((1,)))
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros(vr_shape, dtype=dtype),
v_col=jnp.zeros(vc_shape, dtype=dtype),
v=jnp.zeros((1,), dtype=dtype))
else:
return _UpdateResult(
update=jnp.zeros((1,)),
v_row=jnp.zeros((1,)),
v_col=jnp.zeros((1,)),
v=jnp.zeros(param.shape))
update=jnp.zeros((1,), dtype=dtype),
v_row=jnp.zeros((1,), dtype=dtype),
v_col=jnp.zeros((1,), dtype=dtype),
v=jnp.zeros(param.shape, dtype=dtype))

return _to_state(
jnp.zeros([], jnp.int32), jax.tree.map(_init, params))
Expand All @@ -153,13 +153,13 @@ def update_fn(grads, state, params):
raise ValueError(base.NO_PARAMS_MSG)

def _update(grad, v_row, v_col, v, param, step):
shape = param.shape
shape, dtype = param.shape, param.dtype
decay_rate_t = decay_rate_fn(step - step_offset, decay_rate)

# Scaled by factorized second moment statistics.
new_v_row = jnp.zeros((1,))
new_v_col = jnp.zeros((1,))
new_v = jnp.zeros((1,))
new_v_row = jnp.zeros((1,), dtype=dtype)
new_v_col = jnp.zeros((1,), dtype=dtype)
new_v = jnp.zeros((1,), dtype=dtype)

factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor)
if factored_dims is not None:
Expand All @@ -171,6 +171,8 @@ def _update(grad, v_row, v_col, v, param, step):
new_v_col = (
decay_rate_t * v_col +
(1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1))
new_v_row = new_v_row.astype(dtype)
new_v_col = new_v_col.astype(dtype)
reduced_d1 = d1-1 if d1 > d0 else d1
row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True)
row_factor = (new_v_row / row_col_mean) ** -0.5
Expand All @@ -182,6 +184,7 @@ def _update(grad, v_row, v_col, v, param, step):
else:
grad_sqr = numerics.abs_sq(grad) + epsilon
new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr
new_v = new_v.astype(dtype)
update = grad * (new_v)**-0.5

return _UpdateResult(update, new_v_row, new_v_col, new_v)
Expand Down
49 changes: 49 additions & 0 deletions optax/_src/factorized_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from absl.testing import parameterized

import chex
import jax
import jax.numpy as jnp

from optax._src import factorized
from optax.transforms import _accumulation


class FactorizedTest(parameterized.TestCase):
Expand All @@ -45,6 +47,53 @@ def test_scale_by_factored_rms(self):
chex.assert_tree_all_finite((params, updates, state))
chex.assert_trees_all_equal_shapes(params, updates)

@chex.variants(with_jit=True, without_jit=True, with_device=True)
@parameterized.product(
factorized_dims=(True, False),
dtype=('bfloat16', 'float32')
)
def test_preserve_dtype(self, factorized_dims: bool, dtype: str):
"""Test that the optimizer returns updates of same dtype as params."""
dtype = jnp.dtype(dtype)
opt = factorized.scale_by_factored_rms()
fun = lambda x: jnp.sum(x**2)

if factorized_dims:
# The updates are factored only for large enough parameters
# default min_dim_size_to_factor is 128 so we use 129 here.
params = jnp.ones((129, 129), dtype=dtype)
else:
params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
updates, _ = self.variant(opt.update)(grads, state, params)
self.assertEqual(updates.dtype, params.dtype)

@chex.variants(with_jit=True, without_jit=True, with_device=True)
@parameterized.product(
factorized_dims=(True, False),
dtype=('bfloat16', 'float32')
)
def test_gradient_accumulation(self, factorized_dims, dtype):
"""Test that the optimizers can safely be used with optax.MultiSteps."""
# Checks if https://github.com/google-deepmind/optax/issues/377 is fixed.
dtype = jnp.dtype(dtype)
base_opt = factorized.scale_by_factored_rms()
opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4)

fun = lambda x: jnp.sum(x**2)

if factorized_dims:
# The updates are factored only for large enough parameters
# default min_dim_size_to_factor is 128 so we use 129 here.
params = jnp.ones((129, 129), dtype=dtype)
else:
params = jnp.array([1.0, 2.0], dtype=dtype)
grads = jax.grad(fun)(params)
state = self.variant(opt.init)(params)
updates, _ = self.variant(opt.update)(grads, state, params)
chex.assert_trees_all_equal(updates, jnp.zeros_like(grads))


if __name__ == '__main__':
absltest.main()
32 changes: 19 additions & 13 deletions optax/_src/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,11 +293,17 @@ def scale_by_adam(
Returns:
A `GradientTransformation` object.
Raises:
ValueError: If the selected ``mu_dtype`` induces a dtype promotion of the
dtypes of the parameters.
"""

mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params):
if mu_dtype is not None:
otu.tree_assert_dtype_preserved(params, mu_dtype)
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
nu = otu.tree_zeros_like(params) # Second moment
return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)
Expand Down Expand Up @@ -716,14 +722,15 @@ def scale_by_radam(
A `GradientTransformation` object.
"""

ro_inf = 2./(1 - b2) - 1
def _radam_update(params):
ro = params[0]
mu_hat = params[1]
nu_hat = params[2]
r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro))
ro_inf = 2./(1. - b2) - 1.

def _radam_update(ro, mu_hat, nu_hat):
r = jnp.sqrt((ro - 4.)*(ro - 2.)*ro_inf/((ro_inf - 4.)*(ro_inf - 2.)*ro))
updates = jax.tree.map(
lambda m, v: r*m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat)
lambda m, v: r.astype(m.dtype) * m / (jnp.sqrt(v + eps_root) + eps),
mu_hat,
nu_hat,
)
return updates

def init_fn(params):
Expand All @@ -749,7 +756,7 @@ def update_fn(updates, state, params=None):
nu_hat = otu.tree_bias_correction(nu, b2, count_inc)
updates = jax.tree.map(
lambda t, f: jnp.where(ro >= threshold, t, f),
_radam_update((ro, mu_hat, nu_hat)),
_radam_update(ro, mu_hat, nu_hat),
mu_hat,
)
return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu)
Expand Down Expand Up @@ -1050,7 +1057,7 @@ def scale_by_sm3(
"""

def zeros_for_dim(p):
return [jnp.zeros([s]) for s in p.shape]
return [jnp.zeros([s], dtype=p.dtype) for s in p.shape]

def init_fn(params):
_reject_complex(params)
Expand Down Expand Up @@ -1136,8 +1143,8 @@ def scale_by_novograd(
mu_dtype = utils.canonicalize_dtype(mu_dtype)

def init_fn(params):
mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment
nu = jax.tree.map(lambda _: 0.0, params) # Second moment
mu = otu.tree_zeros_like(params, dtype=mu_dtype)
nu = jax.tree.map(lambda p: jnp.asarray(0.0, dtype=p.dtype), params)
return ScaleByNovogradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu)

def nu_addition(grads):
Expand All @@ -1147,8 +1154,7 @@ def mu_addition(grads, params, nu):
return grads / (jnp.sqrt(nu + eps_root) + eps) + weight_decay * params

def init_nu(grads, nu):
del nu
return jax.tree.map(nu_addition, grads)
return jax.tree.map(lambda g, n: nu_addition(g).astype(n.dtype), grads, nu)

def update_nu(grads, nu):
updates = jax.tree.map(nu_addition, grads)
Expand Down
Loading

0 comments on commit c15e795

Please sign in to comment.