From d06bb483da19ebc2fa13c050282212ead51e7907 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Thu, 12 Sep 2024 15:07:15 -0700 Subject: [PATCH] Ensure optimizers return updates of same dtype as params. Fix #1038, fix #377, fix #1051 PiperOrigin-RevId: 674026550 --- docs/api/utilities.rst | 11 + docs/development.md | 6 +- optax/_src/alias_test.py | 133 ++++++++++-- optax/_src/factorized.py | 29 +-- optax/_src/factorized_test.py | 49 +++++ optax/_src/transform.py | 32 +-- optax/contrib/_cocob.py | 6 +- optax/contrib/_common_test.py | 254 +++++++++++++++++++++-- optax/contrib/_dadapt_adamw.py | 26 ++- optax/contrib/_dog.py | 18 +- optax/contrib/_mechanic.py | 25 ++- optax/contrib/_mechanic_test.py | 88 +------- optax/contrib/_momo.py | 59 +++--- optax/contrib/_prodigy.py | 27 ++- optax/contrib/_reduce_on_plateau.py | 24 ++- optax/contrib/_reduce_on_plateau_test.py | 20 +- optax/contrib/_sam_test.py | 158 +++----------- optax/contrib/_schedule_free.py | 65 ++++-- optax/contrib/_schedule_free_test.py | 141 +++++-------- optax/transforms/_accumulation.py | 6 + optax/tree_utils/__init__.py | 2 + optax/tree_utils/_casting.py | 230 +++++++++++++++++++- optax/tree_utils/_casting_test.py | 110 +++++++++- 23 files changed, 1034 insertions(+), 485 deletions(-) diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index c792944fd..3e6f3fc86 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -92,7 +92,10 @@ Tree NamedTupleKey tree_add tree_add_scalar_mul + tree_cast + tree_check_no_dtype_promotion tree_div + tree_dtype tree_get tree_get_all_with_path tree_l1_norm @@ -122,6 +125,14 @@ Tree add and scalar multiply ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_add_scalar_mul +Tree cast +~~~~~~~~~ +.. autofunction:: tree_cast + +Tree dtype +~~~~~~~~~~ +.. autofunction:: tree_dtype + Tree divide ~~~~~~~~~~~ .. autofunction:: tree_div diff --git a/docs/development.md b/docs/development.md index 581eddc27..9e930ef0c 100644 --- a/docs/development.md +++ b/docs/development.md @@ -53,7 +53,11 @@ years, well-cited (100+ citations), and demonstrate broad utility. if they offer clear advantages over widely used methods. If your algorithm doesn't meet the main package criteria, the {doc}`api/contrib` -directory is perfect for sharing innovative work. +directory is perfect for sharing innovative work. Please make sure that all +common tests (in `optax/contrib/_common_test.py` or `optax/_src/alias_test.py`) +are passed when you propose a new algorithm. These tests ensure the +interoperability of algorithms with different features of optax (such as +gradient accumulation or varying hyperparameters). ## Design Documents diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index a2444c689..d0e2e35b5 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -34,8 +34,10 @@ from optax._src import update from optax.losses import _classification from optax.schedules import _inject +from optax.transforms import _accumulation import optax.tree_utils as otu + import scipy.optimize as scipy_optimize from sklearn import datasets from sklearn import linear_model @@ -163,13 +165,16 @@ def step(params, state): params = initial_params state = opt.init(params) - # A no-op change, to verify that tree map works. - state = otu.tree_map_params(opt, lambda v: v, state) - for _ in range(10000): - params, state = step(params, state) + with self.subTest('Test that tree_map_params works'): + # A no-op change, to verify that tree map works. 
+ state = otu.tree_map_params(opt, lambda v: v, state) + + with self.subTest('Test that optimization works'): + for _ in range(10000): + params, state = step(params, state) - chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) + chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) @chex.all_variants @parameterized.product(_OPTIMIZERS_UNDER_TEST) @@ -210,24 +215,108 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( chex.assert_trees_all_close( new_state_inject.inner_state, new_state, rtol=1e-4) - @parameterized.named_parameters([ - ('float32', 'float32'), - ('bfloat16', 'bfloat16'), - ('complex64', 'complex64'), - ('None', None), - ]) - def test_explicit_dtype(self, dtype): - expected_dtype = jax.dtypes.canonicalize_dtype(dtype) # None -> float32 - tx = alias.sgd(0.1, momentum=0.9, accumulator_dtype=dtype) - trace_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, getattr(trace_state, 'trace').dtype) - tx = alias.adam(0.1, mu_dtype=dtype) - adam_state, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype) - tx = alias.adamw(0.1, mu_dtype=dtype) - adam_state, _, _ = tx.init(jnp.array([0.0, 0.0])) - self.assertEqual(expected_dtype, getattr(adam_state, 'mu').dtype) + @parameterized.product( + params_dtype=('bfloat16', 'float32', 'complex64', None), + state_dtype=('bfloat16', 'float32', 'complex64', None), + opt_name=('sgd_mom', 'adam', 'adamw'), + ) + def test_explicit_dtype(self, params_dtype, state_dtype, opt_name): + if opt_name == 'sgd_mom': + opt = alias.sgd(0.1, momentum=0.9, accumulator_dtype=state_dtype) + attribute_name = 'trace' + elif opt_name in ['adam', 'adamw']: + opt = getattr(alias, opt_name)(0.1, mu_dtype=state_dtype) + attribute_name = 'mu' + else: + raise ValueError(f'Unsupported optimizer: {opt_name}') + + params_dtype = jax.dtypes.canonicalize_dtype(params_dtype) + params = jnp.array([0.0, 0.0], dtype=params_dtype) + state_has_lower_dtype = ( + jnp.promote_types(params_dtype, jnp.dtype(state_dtype)) + == params_dtype + ) + if state_dtype is None or state_has_lower_dtype: + state = opt.init(params) + + with self.subTest('Test that attribute dtype is correct'): + if state_dtype is None: + expected_dtype = params_dtype + else: + expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype) + attribute = otu.tree_get(state, attribute_name) + self.assertEqual(expected_dtype, attribute.dtype) + + with self.subTest( + 'Verifies that the updates keep the same type as params' + ): + updates, _ = opt.update(jnp.ones_like(params), state, params) + self.assertEqual(updates.dtype, params.dtype) + else: + with self.subTest( + 'Test that we forbid setting dtype s.t. updates dtype get promoted to' + ' the state dtype' + ): + with self.assertRaises(ValueError): + opt.init(params) + + # Not testing with `without_device=True` because without_device set the + # variables to the host which appears to convert then the dtype, so we + # lose control of the dtype and the test fails. + @chex.variants( + with_jit=True, without_jit=True, with_device=True, with_pmap=True + ) + @parameterized.product( + _OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') + ) + def test_preserve_dtype(self, opt_name, opt_kwargs, dtype): + """Test that the optimizers return updates of same dtype as params.""" + # When debugging this test, note that operations like + # x = 0.5**jnp.asarray(1, dtype=jnp.int32) + # (appearing in e.g. 
optax.tree_utils.tree_bias_correction) + # are promoted (strictly) to float32 when jitted + # see https://github.com/google/jax/issues/23337 + # This may end up letting updates have a dtype different from params. + # The solution is to fix the dtype of the result to the desired dtype + # (just as done in optax.tree_utils.tree_bias_correction). + dtype = jnp.dtype(dtype) + opt_factory = getattr(alias, opt_name) + opt = opt_factory(**opt_kwargs) + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + grads = jax.grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name == 'polyak_sgd': + update_kwargs = {'value': fun(params)} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + self.assertEqual(updates.dtype, params.dtype) + @chex.variants( + with_jit=True, without_jit=True, with_device=True, with_pmap=True + ) + @parameterized.product(_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32')) + def test_gradient_accumulation(self, opt_name, opt_kwargs, dtype): + """Test that the optimizers can safely be used with optax.MultiSteps.""" + # Checks for issues like https://github.com/google-deepmind/optax/issues/377 + dtype = jnp.dtype(dtype) + opt_factory = getattr(alias, opt_name) + base_opt = opt_factory(**opt_kwargs) + opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4) + + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + grads = jax.grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name == 'polyak_sgd': + update_kwargs = {'value': fun(params)} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) ########################## # ALGORITHM SPECIFIC TESTS diff --git a/optax/_src/factorized.py b/optax/_src/factorized.py index 7379c9ac8..2f040e66d 100644 --- a/optax/_src/factorized.py +++ b/optax/_src/factorized.py @@ -126,23 +126,23 @@ def init_fn(params): """Initialise the optimiser's state.""" def _init(param): - shape = param.shape + shape, dtype = param.shape, param.dtype factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) if factored_dims is not None: d1, d0 = factored_dims vr_shape = np.delete(shape, d0) vc_shape = np.delete(shape, d1) return _UpdateResult( - update=jnp.zeros((1,)), - v_row=jnp.zeros(vr_shape), - v_col=jnp.zeros(vc_shape), - v=jnp.zeros((1,))) + update=jnp.zeros((1,), dtype=dtype), + v_row=jnp.zeros(vr_shape, dtype=dtype), + v_col=jnp.zeros(vc_shape, dtype=dtype), + v=jnp.zeros((1,), dtype=dtype)) else: return _UpdateResult( - update=jnp.zeros((1,)), - v_row=jnp.zeros((1,)), - v_col=jnp.zeros((1,)), - v=jnp.zeros(param.shape)) + update=jnp.zeros((1,), dtype=dtype), + v_row=jnp.zeros((1,), dtype=dtype), + v_col=jnp.zeros((1,), dtype=dtype), + v=jnp.zeros(param.shape, dtype=dtype)) return _to_state( jnp.zeros([], jnp.int32), jax.tree.map(_init, params)) @@ -153,13 +153,13 @@ def update_fn(grads, state, params): raise ValueError(base.NO_PARAMS_MSG) def _update(grad, v_row, v_col, v, param, step): - shape = param.shape + shape, dtype = param.shape, param.dtype decay_rate_t = decay_rate_fn(step - step_offset, decay_rate) # Scaled by factorized second moment statistics. 
- new_v_row = jnp.zeros((1,)) - new_v_col = jnp.zeros((1,)) - new_v = jnp.zeros((1,)) + new_v_row = jnp.zeros((1,), dtype=dtype) + new_v_col = jnp.zeros((1,), dtype=dtype) + new_v = jnp.zeros((1,), dtype=dtype) factored_dims = _factored_dims(shape, factored, min_dim_size_to_factor) if factored_dims is not None: @@ -171,6 +171,8 @@ def _update(grad, v_row, v_col, v, param, step): new_v_col = ( decay_rate_t * v_col + (1. - decay_rate_t) * jnp.mean(grad_sqr, axis=d1)) + new_v_row = new_v_row.astype(dtype) + new_v_col = new_v_col.astype(dtype) reduced_d1 = d1-1 if d1 > d0 else d1 row_col_mean = jnp.mean(new_v_row, axis=reduced_d1, keepdims=True) row_factor = (new_v_row / row_col_mean) ** -0.5 @@ -182,6 +184,7 @@ def _update(grad, v_row, v_col, v, param, step): else: grad_sqr = numerics.abs_sq(grad) + epsilon new_v = decay_rate_t * v + (1. - decay_rate_t) * grad_sqr + new_v = new_v.astype(dtype) update = grad * (new_v)**-0.5 return _UpdateResult(update, new_v_row, new_v_col, new_v) diff --git a/optax/_src/factorized_test.py b/optax/_src/factorized_test.py index 71f596d90..34ae3a1ff 100644 --- a/optax/_src/factorized_test.py +++ b/optax/_src/factorized_test.py @@ -18,9 +18,11 @@ from absl.testing import parameterized import chex +import jax import jax.numpy as jnp from optax._src import factorized +from optax.transforms import _accumulation class FactorizedTest(parameterized.TestCase): @@ -45,6 +47,53 @@ def test_scale_by_factored_rms(self): chex.assert_tree_all_finite((params, updates, state)) chex.assert_trees_all_equal_shapes(params, updates) + @chex.variants(with_jit=True, without_jit=True, with_device=True) + @parameterized.product( + factorized_dims=(True, False), + dtype=('bfloat16', 'float32') + ) + def test_preserve_dtype(self, factorized_dims: bool, dtype: str): + """Test that the optimizer returns updates of same dtype as params.""" + dtype = jnp.dtype(dtype) + opt = factorized.scale_by_factored_rms() + fun = lambda x: jnp.sum(x**2) + + if factorized_dims: + # The updates are factored only for large enough parameters + # default min_dim_size_to_factor is 128 so we use 129 here. + params = jnp.ones((129, 129), dtype=dtype) + else: + params = jnp.array([1.0, 2.0], dtype=dtype) + grads = jax.grad(fun)(params) + state = self.variant(opt.init)(params) + updates, _ = self.variant(opt.update)(grads, state, params) + self.assertEqual(updates.dtype, params.dtype) + + @chex.variants(with_jit=True, without_jit=True, with_device=True) + @parameterized.product( + factorized_dims=(True, False), + dtype=('bfloat16', 'float32') + ) + def test_gradient_accumulation(self, factorized_dims, dtype): + """Test that the optimizers can safely be used with optax.MultiSteps.""" + # Checks if https://github.com/google-deepmind/optax/issues/377 is fixed. + dtype = jnp.dtype(dtype) + base_opt = factorized.scale_by_factored_rms() + opt = _accumulation.MultiSteps(base_opt, every_k_schedule=4) + + fun = lambda x: jnp.sum(x**2) + + if factorized_dims: + # The updates are factored only for large enough parameters + # default min_dim_size_to_factor is 128 so we use 129 here. 
+ params = jnp.ones((129, 129), dtype=dtype) + else: + params = jnp.array([1.0, 2.0], dtype=dtype) + grads = jax.grad(fun)(params) + state = self.variant(opt.init)(params) + updates, _ = self.variant(opt.update)(grads, state, params) + chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) + if __name__ == '__main__': absltest.main() diff --git a/optax/_src/transform.py b/optax/_src/transform.py index e167fd6d5..fcb1f3859 100644 --- a/optax/_src/transform.py +++ b/optax/_src/transform.py @@ -293,11 +293,17 @@ def scale_by_adam( Returns: A `GradientTransformation` object. + + Raises: + ValueError: If the selected ``mu_dtype`` induces a dtype promotion of the + dtypes of the parameters. """ mu_dtype = utils.canonicalize_dtype(mu_dtype) def init_fn(params): + if mu_dtype is not None: + otu.tree_assert_dtype_preserved(params, mu_dtype) mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment nu = otu.tree_zeros_like(params) # Second moment return ScaleByAdamState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) @@ -716,14 +722,15 @@ def scale_by_radam( A `GradientTransformation` object. """ - ro_inf = 2./(1 - b2) - 1 - def _radam_update(params): - ro = params[0] - mu_hat = params[1] - nu_hat = params[2] - r = jnp.sqrt((ro - 4)*(ro - 2)*ro_inf/((ro_inf - 4)*(ro_inf - 2)*ro)) + ro_inf = 2./(1. - b2) - 1. + + def _radam_update(ro, mu_hat, nu_hat): + r = jnp.sqrt((ro - 4.)*(ro - 2.)*ro_inf/((ro_inf - 4.)*(ro_inf - 2.)*ro)) updates = jax.tree.map( - lambda m, v: r*m / (jnp.sqrt(v + eps_root) + eps), mu_hat, nu_hat) + lambda m, v: r.astype(m.dtype) * m / (jnp.sqrt(v + eps_root) + eps), + mu_hat, + nu_hat, + ) return updates def init_fn(params): @@ -749,7 +756,7 @@ def update_fn(updates, state, params=None): nu_hat = otu.tree_bias_correction(nu, b2, count_inc) updates = jax.tree.map( lambda t, f: jnp.where(ro >= threshold, t, f), - _radam_update((ro, mu_hat, nu_hat)), + _radam_update(ro, mu_hat, nu_hat), mu_hat, ) return updates, ScaleByAdamState(count=count_inc, mu=mu, nu=nu) @@ -1050,7 +1057,7 @@ def scale_by_sm3( """ def zeros_for_dim(p): - return [jnp.zeros([s]) for s in p.shape] + return [jnp.zeros([s], dtype=p.dtype) for s in p.shape] def init_fn(params): _reject_complex(params) @@ -1136,8 +1143,8 @@ def scale_by_novograd( mu_dtype = utils.canonicalize_dtype(mu_dtype) def init_fn(params): - mu = otu.tree_zeros_like(params, dtype=mu_dtype) # First moment - nu = jax.tree.map(lambda _: 0.0, params) # Second moment + mu = otu.tree_zeros_like(params, dtype=mu_dtype) + nu = jax.tree.map(lambda p: jnp.asarray(0.0, dtype=p.dtype), params) return ScaleByNovogradState(count=jnp.zeros([], jnp.int32), mu=mu, nu=nu) def nu_addition(grads): @@ -1147,8 +1154,7 @@ def mu_addition(grads, params, nu): return grads / (jnp.sqrt(nu + eps_root) + eps) + weight_decay * params def init_nu(grads, nu): - del nu - return jax.tree.map(nu_addition, grads) + return jax.tree.map(lambda g, n: nu_addition(g).astype(n.dtype), grads, nu) def update_nu(grads, nu): updates = jax.tree.map(nu_addition, grads) diff --git a/optax/contrib/_cocob.py b/optax/contrib/_cocob.py index edf13a021..7fcb77a27 100644 --- a/optax/contrib/_cocob.py +++ b/optax/contrib/_cocob.py @@ -25,6 +25,7 @@ from optax._src import base from optax._src import combine from optax._src import transform +import optax.tree_utils as otu class COCOBState(NamedTuple): @@ -58,8 +59,9 @@ def scale_by_cocob( """ def init_fn(params): - init_adapt = jax.tree.map(lambda p: jnp.zeros(p.shape), params) - init_scale = jax.tree.map(lambda p: eps * jnp.ones(p.shape), 
params) + init_adapt = otu.tree_zeros_like(params) + init_scale = otu.tree_ones_like(params) + init_scale = otu.tree_scalar_mul(eps, init_scale) return COCOBState( init_particles=params, cumulative_gradients=init_adapt, diff --git a/optax/contrib/_common_test.py b/optax/contrib/_common_test.py index c3db5823e..f044d7a11 100644 --- a/optax/contrib/_common_test.py +++ b/optax/contrib/_common_test.py @@ -15,24 +15,29 @@ """Common tests for contributed optimizers. Additional specific tests are implemented in additional files -(see e.g. sam_test) """ +import functools +import inspect + from absl.testing import absltest from absl.testing import parameterized import chex import jax import jax.numpy as jnp from optax import contrib +from optax._src import alias +from optax._src import combine from optax._src import numerics from optax._src import update from optax.schedules import _inject +from optax.transforms import _accumulation from optax.tree_utils import _state_utils # Testing contributions coded as GradientTransformations -_OPTIMIZERS_UNDER_TEST = ( +_MAIN_OPTIMIZERS_UNDER_TEST = [ dict(opt_name='acprop', opt_kwargs=dict(learning_rate=1e-3)), - dict(opt_name='cocob', opt_kwargs=dict(alpha=100.0, eps=1e-8)), + dict(opt_name='cocob', opt_kwargs={}), dict(opt_name='cocob', opt_kwargs=dict(weight_decay=1e-2)), dict(opt_name='dadapt_adamw', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='dog', opt_kwargs=dict(learning_rate=1.0)), @@ -40,7 +45,97 @@ dict(opt_name='momo', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='momo_adam', opt_kwargs=dict(learning_rate=1e-1)), dict(opt_name='prodigy', opt_kwargs=dict(learning_rate=1e-1)), + dict( + opt_name='schedule_free_sgd', + opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), + ), + dict( + opt_name='schedule_free_adamw', + opt_kwargs=dict(learning_rate=1e-2, warmup_steps=5000), + ), +] +for optimizer in _MAIN_OPTIMIZERS_UNDER_TEST: + optimizer['wrapper_name'] = None + optimizer['wrapper_kwargs'] = None + +# Testing contributions coded as wrappers +# (just with sgd as we just want the behavior of the wrapper) +_MAIN_OPTIMIZERS_UNDER_TEST += [ + dict( + opt_name='sgd', + opt_kwargs=dict(learning_rate=1e-1), + wrapper_name='mechanize', + wrapper_kwargs=dict(weight_decay=0.0), + ), + dict( + opt_name='sgd', + opt_kwargs=dict(learning_rate=1e-2), + wrapper_name='schedule_free', + wrapper_kwargs=dict(learning_rate=1e-2), + ), + dict( + opt_name='sgd', + opt_kwargs=dict(learning_rate=1e-3), + wrapper_name='reduce_on_plateau', + wrapper_kwargs={}, + ), +] + +# Adding here instantiations of wrappers with any base optimizer +_BASE_OPTIMIZERS = [ + dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), + dict(opt_name='adam', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='lion', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), + dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), + dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1.0)), + dict( + opt_name='optimistic_gradient_descent', + opt_kwargs=dict(learning_rate=1.0, alpha=0.7, beta=0.1), + ), + dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)), + 
dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), + dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), + dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), +] +# TODO(harshm): make LARS and Fromage work with mechanic. +_OTHER_OPTIMIZERS_UNDER_TEST = [ + dict( + opt_name=base_opt['opt_name'], + opt_kwargs=base_opt['opt_kwargs'], + wrapper_name='mechanize', + wrapper_kwargs=dict(weight_decay=0.0), + ) + for base_opt in _BASE_OPTIMIZERS +] + +_ALL_OPTIMIZERS_UNDER_TEST = tuple( + _MAIN_OPTIMIZERS_UNDER_TEST + _OTHER_OPTIMIZERS_UNDER_TEST ) +_MAIN_OPTIMIZERS_UNDER_TEST = tuple(_MAIN_OPTIMIZERS_UNDER_TEST) + + +def _get_opt_factory(opt_name): + """Get optimizer factory.""" + if hasattr(contrib, opt_name): + return getattr(contrib, opt_name) + if hasattr(alias, opt_name): + return getattr(alias, opt_name) + raise ValueError(f'Unknown optimizer: {opt_name}') + + +def _wrap_opt(opt, wrapper_name, wrapper_kwargs): + if wrapper_name == 'reduce_on_plateau': + return combine.chain(opt, contrib.reduce_on_plateau(**wrapper_kwargs)) + else: + return getattr(contrib, wrapper_name)(opt, **wrapper_kwargs) def _setup_parabola(dtype): @@ -75,18 +170,32 @@ def get_updates(params): class ContribTest(chex.TestCase): @parameterized.product( - _OPTIMIZERS_UNDER_TEST, + _ALL_OPTIMIZERS_UNDER_TEST, target=(_setup_parabola, _setup_rosenbrock), - dtype=(jnp.float32,), + dtype=('float32',), ) - def test_optimizers(self, opt_name, opt_kwargs, target, dtype): - opt = getattr(contrib, opt_name)(**opt_kwargs) + def test_optimizers( + self, + opt_name, + opt_kwargs, + wrapper_name, + wrapper_kwargs, + target, + dtype, + ): + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) initial_params, final_params, get_updates = target(dtype) @jax.jit def step(params, state): value, updates = get_updates(params) - if opt_name in ['momo', 'momo_adam']: + if ( + opt_name in ['momo', 'momo_adam'] + or wrapper_name == 'reduce_on_plateau' + ): update_kwargs = {'value': value} else: update_kwargs = {} @@ -96,31 +205,63 @@ def step(params, state): params = initial_params state = opt.init(params) - # A no-op change, to verify that tree map works. - state = _state_utils.tree_map_params(opt, lambda v: v, state) + with self.subTest('Test that tree_map_params works'): + # A no-op change, to verify that tree map works. 
+      state = _state_utils.tree_map_params(opt, lambda v: v, state)

-    def f(params_state, _):
-      return step(*params_state), None
+    with self.subTest('Test that optimization works'):

-    (params, _), _ = jax.lax.scan(f, (params, state), length=30_000)
+      def f(params_state, _):
+        return step(*params_state), None

-    chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)
+      (params, state), _ = jax.lax.scan(f, (params, state), length=30_000)
+
+      if (
+          opt_name in ['schedule_free_sgd', 'schedule_free_adamw']
+          or wrapper_name == 'schedule_free'
+      ):
+        params = contrib.schedule_free_eval_params(state, params)
+      chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2)

   @chex.all_variants
-  @parameterized.product(_OPTIMIZERS_UNDER_TEST)
+  @parameterized.product(_MAIN_OPTIMIZERS_UNDER_TEST)
   def test_optimizers_can_be_wrapped_in_inject_hyperparams(
-      self, opt_name, opt_kwargs
+      self, opt_name, opt_kwargs, wrapper_name=None, wrapper_kwargs=None
   ):
     """Checks that optimizers can be wrapped in inject_hyperparams."""
     # See also https://github.com/deepmind/optax/issues/412.
-    opt_factory = getattr(contrib, opt_name)
-    opt = opt_factory(**opt_kwargs)
-    opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs)
+    # When debugging this, make sure that options like weight decay are
+    # checked by asserting whether such a value is None or not (see e.g. the
+    # logic in schedule_free_adamw). Some hyperparameters may not be supported
+    # by inject_hyperparams (e.g. warmup_steps). In that case (if you're sure
+    # you can ignore such a hyperparameter), add the exception below.
+    if wrapper_name == 'reduce_on_plateau':
+      # TODO(vroulet): discuss adding support for reduce_on_plateau
+      # so as to remove all assertions in its definition
+      self.skipTest('reduce_on_plateau is not supported by inject_hyperparams.')
+    if wrapper_name is None:
+      factory = _get_opt_factory(opt_name)
+      hparams = opt_kwargs
+    else:
+      base_opt = _get_opt_factory(opt_name)(**opt_kwargs)
+      factory = getattr(contrib, wrapper_name)
+      factory = functools.partial(factory, base_opt)
+      hparams = wrapper_kwargs
+    opt = factory(**hparams)
+
+    # Add here the hyperparameters that cannot be injected with
+    # inject_hyperparams.
+    static_args = []
+    for uninjectable_hparam in ['warmup_steps', 'num_betas']:
+      if uninjectable_hparam in inspect.signature(factory).parameters.keys():
+        static_args.append(uninjectable_hparam)
+    static_args = tuple(static_args)
+    opt_inject = _inject.inject_hyperparams(factory, static_args)(**hparams)
     params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))]
     grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))]
-    if opt_name in ['momo', 'momo_adam']:
+    if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau':
       update_kwargs = {'value': jnp.array(1.0)}
     else:
       update_kwargs = {}
@@ -136,12 +277,81 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams(
     )

     with self.subTest('Equality of updates.'):
-      chex.assert_trees_all_close(updates_inject, updates, rtol=1e-4)
+      chex.assert_trees_all_close(updates_inject, updates, rtol=1e-5)

     with self.subTest('Equality of new optimizer states.'):
       chex.assert_trees_all_close(
-          new_state_inject.inner_state, new_state, rtol=1e-4
+          new_state_inject.inner_state, new_state, rtol=1e-5, atol=1e-5
       )
+  # Not testing with `without_device=True` because without_device sets the
+  # variables to the host, which appears to then convert the dtype, so we
+  # lose control of the dtype and the test fails. 
+ @chex.variants( + with_jit=True, without_jit=True, with_device=True, with_pmap=True + ) + @parameterized.product( + _MAIN_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') + ) + def test_preserve_dtype( + self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None + ): + """Test that the optimizers return updates of same dtype as params.""" + # When debugging this test, note that operations like + # x = 0.5**jnp.asarray(1, dtype=jnp.int32) + # (appearing in e.g. optax.tree_utils.tree_bias_correction) + # are promoted (strictly) to float32 when jitted + # see https://github.com/google/jax/issues/23337 + # This may end up letting updates have a dtype different from params. + # The solution is to fix the dtype of the result to the desired dtype + # (just as done in optax.tree_utils.tree_bias_correction). + # Otherwise, just make sure that all variables defined in the optimizer have + # the same dtype as the parameters. + dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + value, grads = jax.value_and_grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': + update_kwargs = {'value': value} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + self.assertEqual(updates.dtype, params.dtype) + + @chex.variants( + with_jit=True, without_jit=True, with_device=True, with_pmap=True + ) + @parameterized.product( + _MAIN_OPTIMIZERS_UNDER_TEST, dtype=('bfloat16', 'float32') + ) + def test_gradient_accumulation( + self, opt_name, opt_kwargs, dtype, wrapper_name=None, wrapper_kwargs=None + ): + """Test that the optimizers can safely be used with optax.MultiSteps.""" + # Checks for issues like https://github.com/google-deepmind/optax/issues/377 + # Should pass as long as test_preserve_dtype passes. 
+ dtype = jnp.dtype(dtype) + opt = _get_opt_factory(opt_name)(**opt_kwargs) + if wrapper_name is not None: + opt = _wrap_opt(opt, wrapper_name, wrapper_kwargs) + opt = _accumulation.MultiSteps(opt, every_k_schedule=4) + + fun = lambda x: jnp.sum(x**2) + + params = jnp.array([1.0, 2.0], dtype=dtype) + value, grads = jax.value_and_grad(fun)(params) + state = self.variant(opt.init)(params) + if opt_name in ['momo', 'momo_adam'] or wrapper_name == 'reduce_on_plateau': + update_kwargs = {'value': value} + else: + update_kwargs = {} + updates, _ = self.variant(opt.update)(grads, state, params, **update_kwargs) + chex.assert_trees_all_equal(updates, jnp.zeros_like(grads)) + if __name__ == '__main__': absltest.main() diff --git a/optax/contrib/_dadapt_adamw.py b/optax/contrib/_dadapt_adamw.py index 6a297415b..914b48d73 100644 --- a/optax/contrib/_dadapt_adamw.py +++ b/optax/contrib/_dadapt_adamw.py @@ -22,9 +22,9 @@ import chex import jax import jax.numpy as jnp -from optax import tree_utils from optax._src import base from optax._src import numerics +import optax.tree_utils as otu class DAdaptAdamWState(NamedTuple): @@ -72,11 +72,15 @@ def dadapt_adamw( """ def init_fn(params: base.Params) -> DAdaptAdamWState: - exp_avg = jax.tree.map(lambda p: jnp.zeros(p.shape, jnp.float32), params) - exp_avg_sq = jax.tree.map(lambda p: jnp.zeros(p.shape, jnp.float32), params) - grad_sum = jax.tree.map(lambda p: jnp.zeros(p.shape, jnp.float32), params) - estim_lr = jnp.asarray(estim_lr0, jnp.float32) - numerator_weighted = jnp.zeros([], jnp.float32) + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') + exp_avg = otu.tree_zeros_like(params) + exp_avg_sq = otu.tree_zeros_like(params) + grad_sum = otu.tree_zeros_like(params) + estim_lr = jnp.asarray(estim_lr0, dtype=params_dtype) + numerator_weighted = jnp.zeros([], dtype=params_dtype) count = jnp.zeros([], jnp.int32) return DAdaptAdamWState( exp_avg, exp_avg_sq, grad_sum, estim_lr, numerator_weighted, count @@ -95,12 +99,14 @@ def update_fn( sched = learning_rate(count) if callable(learning_rate) else learning_rate grad_sum = state.grad_sum numerator_weighted = state.numerator_weighted - bc = ((1 - beta2 ** (count + 1)) ** 0.5) / (1 - beta1 ** (count + 1)) + count_inc = numerics.safe_increment(count) + bc = ((1 - beta2 ** count_inc) ** 0.5) / (1 - beta1 ** count_inc) dlr = state.estim_lr * sched * bc + dlr = dlr.astype(numerator_weighted.dtype) s_weighted = jax.tree.map( lambda sk, eas: sk / (jnp.sqrt(eas) + eps), grad_sum, state.exp_avg_sq ) - numerator_acum = tree_utils.tree_vdot(updates, s_weighted) + numerator_acum = otu.tree_vdot(updates, s_weighted) exp_avg = jax.tree.map( lambda ea, g: beta1 * ea + (1 - beta1) * dlr * g, state.exp_avg, updates ) @@ -112,7 +118,7 @@ def update_fn( grad_sum = jax.tree.map( lambda sk, g: sb2 * sk + (1 - sb2) * dlr * g, grad_sum, updates ) - grad_sum_l1 = tree_utils.tree_sum(jax.tree.map(jnp.abs, grad_sum)) + grad_sum_l1 = otu.tree_sum(jax.tree.map(jnp.abs, grad_sum)) numerator_weighted = ( sb2 * numerator_weighted + (1 - sb2) * dlr * numerator_acum ) @@ -130,7 +136,7 @@ def update_fn( grad_sum, estim_lr, numerator_weighted, - numerics.safe_increment(count), + count_inc, ) return p_update, new_state diff --git a/optax/contrib/_dog.py b/optax/contrib/_dog.py index c8249ab8a..b3967141f 100644 --- a/optax/contrib/_dog.py +++ b/optax/contrib/_dog.py @@ -27,10 
+27,10 @@ import chex import jax import jax.numpy as jnp -from optax import tree_utils as otu from optax._src import base from optax._src import combine from optax._src import transform +import optax.tree_utils as otu class DoGState(NamedTuple): @@ -89,11 +89,15 @@ def scale_by_dog( """ def init_fn(params: base.Params) -> DoGState: + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') return DoGState( first_step=jnp.asarray(True), init_params=otu.tree_zeros_like(params), - estim_dist=jnp.asarray(0.0), - sum_sq_norm_grads=jnp.asarray(0.0), + estim_dist=jnp.asarray(0.0, dtype=params_dtype), + sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype), ) def update_fn( @@ -252,14 +256,18 @@ def scale_by_dowg( """ def init_fn(params: base.Params) -> DoWGState: + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') if init_estim_sq_dist is None: init_estim_sq_dist_ = eps else: init_estim_sq_dist_ = init_estim_sq_dist return DoWGState( init_params=params, - estim_sq_dist=jnp.asarray(init_estim_sq_dist_), - weighted_sq_norm_grads=jnp.asarray(0.0), + estim_sq_dist=jnp.asarray(init_estim_sq_dist_, dtype=params_dtype), + weighted_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype), ) def update_fn( diff --git a/optax/contrib/_mechanic.py b/optax/contrib/_mechanic.py index faf876f91..86b12d043 100644 --- a/optax/contrib/_mechanic.py +++ b/optax/contrib/_mechanic.py @@ -32,9 +32,9 @@ import jax import jax.numpy as jnp -from optax import tree_utils from optax._src import base from optax._src import numerics +import optax.tree_utils as otu class MechanicState(NamedTuple): @@ -108,10 +108,14 @@ def mechanize( def init_fn(params: base.Params) -> MechanicState: x0 = params - r = jnp.zeros([num_betas,], jnp.float32) - v = jnp.zeros([num_betas,], jnp.float32) - m = jnp.zeros([num_betas,], jnp.float32) - s = jnp.ones([num_betas,], jnp.float32) * s_init + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') + r = jnp.zeros([num_betas,], dtype=params_dtype) + v = jnp.zeros([num_betas,], dtype=params_dtype) + m = jnp.zeros([num_betas,], dtype=params_dtype) + s = jnp.ones([num_betas,], dtype=params_dtype) * s_init return MechanicState( base_optimizer_state=base_optimizer.init(params), count=jnp.zeros([], jnp.int32), @@ -142,8 +146,8 @@ def update_fn( # Add weight decay to raw gradients, note that this is othogonal to any # weight decay applied to inner_optimizer updates. s_sum = jnp.sum(state.s) - grad_norm = tree_utils.tree_l2_norm(updates) - param_norm = tree_utils.tree_l2_norm(params) + grad_norm = otu.tree_l2_norm(updates) + param_norm = otu.tree_l2_norm(params) def add_weight_decay(gi, pi): return gi + weight_decay * s_sum * grad_norm / (param_norm + eps) * pi @@ -167,12 +171,15 @@ def add_weight_decay(gi, pi): ) # Now we are ready to run the actual Mechanic algorithm. - h = tree_utils.tree_vdot(updates, delta_prev) + h = otu.tree_vdot(updates, delta_prev) # This clipping was not part of the original paper but we introduced it # a little later. 
clipped_h = jax.lax.clamp(-state.m, jnp.ones_like(state.m) * h, state.m) - betas = jnp.array([1.0 - 0.1**betai for betai in range(1, num_betas + 1)]) + betas = jnp.array( + [1.0 - 0.1**betai for betai in range(1, num_betas + 1)], + dtype=state.s.dtype, + ) m = jnp.maximum(betas * state.m, jnp.abs(h) + eps) v = (betas**2) * state.v + h**2 diff --git a/optax/contrib/_mechanic_test.py b/optax/contrib/_mechanic_test.py index 9049685e3..7066950ee 100644 --- a/optax/contrib/_mechanic_test.py +++ b/optax/contrib/_mechanic_test.py @@ -12,79 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for `mechanic.py`.""" +"""Specific tests for `mechanic.py`, see `common_test.py` for usual tests.""" from typing import NamedTuple from absl.testing import absltest -from absl.testing import parameterized import chex import jax import jax.numpy as jnp import numpy as np -from optax._src import alias from optax._src import base -from optax._src import numerics from optax._src import update from optax.contrib import _mechanic from optax.tree_utils import _state_utils -# TODO(harshm): make LARS and Fromage work with mechanic. -_OPTIMIZERS_UNDER_TEST = ( - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name='adam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1.0)), - dict( - opt_name='lion', opt_kwargs=dict(learning_rate=1.0, b1=0.99), - ), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), - dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1.0)), - dict( - opt_name='optimistic_gradient_descent', - opt_kwargs=dict(learning_rate=1.0, alpha=0.7, beta=0.1), - ), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), -) - - -def _setup_parabola(dtype): - """Quadratic function as an optimization target.""" - initial_params = jnp.array([-1.0, 10.0, 1.0], dtype=dtype) - final_params = jnp.array([1.0, -1.0, 1.0], dtype=dtype) - - @jax.grad - def get_updates(params): - return jnp.sum(numerics.abs_sq(params - final_params)) - - return initial_params, final_params, get_updates - - -def _setup_rosenbrock(dtype): - """Rosenbrock function as an optimization target.""" - a = 1.0 - b = 100.0 - - initial_params = jnp.array([0.0, 0.0], dtype=dtype) - final_params = jnp.array([a, a**2], dtype=dtype) - - @jax.grad - def get_updates(params): - return (numerics.abs_sq(a - params[0]) + - b * numerics.abs_sq(params[1] - params[0]**2)) - - return initial_params, final_params, get_updates - - class OptimizerTestState(NamedTuple): """Inner optimizer state for the Mechanic tests.""" aggregate_grads: base.Params @@ -161,34 +103,6 @@ def test_mechanized(self): chex.assert_trees_all_close(final_params, params) chex.assert_tree_all_finite((final_params, final_state)) - @parameterized.product( - 
_OPTIMIZERS_UNDER_TEST, - target=(_setup_parabola, _setup_rosenbrock), - dtype=(jnp.float32,), - ) - def test_optimization(self, opt_name, opt_kwargs, target, dtype): - - opt = getattr(alias, opt_name)(**opt_kwargs) - opt = _mechanic.mechanize(opt, weight_decay=0.0) - initial_params, final_params, get_updates = target(dtype) - - @jax.jit - def step(params, state): - updates = get_updates(params) - updates, state = opt.update(updates, state, params) - params = update.apply_updates(params, updates) - return params, state - - params = initial_params - state = opt.init(params) - # A no-op change, to verify that tree map works. - state = _state_utils.tree_map_params(opt, lambda v: v, state) - - for _ in range(25000): - params, state = step(params, state) - - chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) - if __name__ == '__main__': absltest.main() diff --git a/optax/contrib/_momo.py b/optax/contrib/_momo.py index 8281a8b44..f3465f42c 100644 --- a/optax/contrib/_momo.py +++ b/optax/contrib/_momo.py @@ -26,9 +26,9 @@ import jax from jax import lax import jax.numpy as jnp -from optax import tree_utils from optax._src import base from optax._src import numerics +import optax.tree_utils as otu class MomoState(NamedTuple): @@ -104,10 +104,14 @@ def momo( """ def init_fn(params: base.Params) -> MomoState: - exp_avg = jax.tree.map(lambda p: jnp.zeros(p.shape), params) - barf = jnp.zeros([], jnp.float32) - gamma = jnp.zeros([], jnp.float32) - init_lb = jnp.array(lower_bound, jnp.float32) + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') + exp_avg = otu.tree_zeros_like(params) + barf = jnp.zeros([], dtype=params_dtype) + gamma = jnp.zeros([], dtype=params_dtype) + init_lb = jnp.array(lower_bound, dtype=params_dtype) count = jnp.zeros([], jnp.int32) return MomoState(exp_avg, barf, gamma, init_lb, count) @@ -127,14 +131,14 @@ def update_fn( Use ``jax.value_and_grad`` for this.""") count = state.count # initialize at first gradient, and loss - bt = lax.cond(count == 0, lambda: 0.0, lambda: beta) + bt = jnp.where(count == 0, 0.0, beta) barf = bt * state.barf + (1 - bt) * value exp_avg = jax.tree.map( lambda ea, g: bt * ea + (1 - bt) * g, state.exp_avg, updates ) - gamma = bt * state.gamma + (1 - bt) * tree_utils.tree_vdot(updates, params) - exp_avg_norm = tree_utils.tree_l2_norm(exp_avg, squared=True) - iprod = tree_utils.tree_vdot(exp_avg, params) + gamma = bt * state.gamma + (1 - bt) * otu.tree_vdot(updates, params) + exp_avg_norm = otu.tree_l2_norm(exp_avg, squared=True) + iprod = otu.tree_vdot(exp_avg, params) alpha = learning_rate(count) if callable(learning_rate) else learning_rate # Reset lower bound if adapt_lower_bound: @@ -152,7 +156,7 @@ def update_fn( (1 + alpha * weight_decay) * (barf - this_lb - gamma) + iprod, 0.0 ) / (exp_avg_norm) # if denom is zero, take no step - t1 = lax.cond(exp_avg_norm <= jnp.finfo(float).eps, lambda: 0.0, lambda: t1) + t1 = jnp.where(exp_avg_norm <= jnp.finfo(float).eps, 0.0, t1) tau = jnp.minimum(alpha, t1) p_update = jax.tree.map( lambda ea, p: -(alpha * weight_decay) / (1 + alpha * weight_decay) * p @@ -257,11 +261,15 @@ def momo_adam( """ def init_fn(params: base.Params) -> MomoAdamState: - exp_avg = jax.tree.map(lambda p: jnp.zeros(p.shape), params) - exp_avg_sq = jax.tree.map(lambda p: jnp.zeros(p.shape, jnp.float32), params) - barf = jnp.zeros([], jnp.float32) - 
gamma = jnp.zeros([], jnp.float32) - init_lb = jnp.array(lower_bound, jnp.float32) + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') + exp_avg = otu.tree_zeros_like(params) + exp_avg_sq = otu.tree_zeros_like(params) + barf = jnp.zeros([], dtype=params_dtype) + gamma = jnp.zeros([], dtype=params_dtype) + init_lb = jnp.array(lower_bound, dtype=params_dtype) count = jnp.zeros([], jnp.int32) return MomoAdamState(exp_avg, exp_avg_sq, barf, gamma, init_lb, count) @@ -280,6 +288,7 @@ def update_fn( raise ValueError("""You need to pass the latest loss value to Momo. Use ``jax.value_and_grad`` for this.""") count = state.count + count_inc = numerics.safe_increment(count) barf = b1 * state.barf + (1 - b1) * value exp_avg = jax.tree.map( lambda ea, g: b1 * ea + (1 - b1) * g, state.exp_avg, updates @@ -289,25 +298,25 @@ def update_fn( state.exp_avg_sq, updates, ) - bc2 = 1 - b2 ** (count + 1) + bc2 = jnp.asarray(1 - b2 ** count_inc, dtype=barf.dtype) precond = jax.tree.map(lambda eas: eps + jnp.sqrt(eas / bc2), exp_avg_sq) exp_avg_weighted = jax.tree.map( lambda ea, prec: ea / prec, exp_avg, precond ) - exp_avg_norm = tree_utils.tree_vdot(exp_avg, exp_avg_weighted) - gamma = b1 * state.gamma + (1 - b1) * tree_utils.tree_vdot(updates, params) - iprod = tree_utils.tree_vdot(exp_avg, params) + exp_avg_norm = otu.tree_vdot(exp_avg, exp_avg_weighted) + gamma = b1 * state.gamma + (1 - b1) * otu.tree_vdot(updates, params) + iprod = otu.tree_vdot(exp_avg, params) alpha = learning_rate(count) if callable(learning_rate) else learning_rate - bc1 = 1 - b1 ** (count + 1) + bc1 = jnp.asarray(1 - b1 ** count_inc, dtype=barf.dtype) # Reset lower bound if adapt_lower_bound: cap = (1 + alpha * weight_decay) * (barf - gamma) + iprod - this_lb = lax.cond( + this_lb = jnp.where( cap < (1 + alpha * weight_decay) * bc1 * state.lb, - lambda: jnp.maximum( + jnp.maximum( cap / (2 * bc1 * (1 + alpha * weight_decay)), lower_bound ), - lambda: state.lb, + state.lb, ) else: this_lb = state.lb @@ -315,7 +324,7 @@ def update_fn( (1 + alpha * weight_decay) * (barf - bc1 * this_lb - gamma) + iprod, 0.0 ) / (exp_avg_norm) # if denom is zero, take no step - t1 = lax.cond(exp_avg_norm <= jnp.finfo(float).eps, lambda: 0.0, lambda: t1) + t1 = jnp.where(exp_avg_norm <= jnp.finfo(float).eps, 0.0, t1) tau = jnp.minimum(alpha / bc1, t1) p_update = jax.tree.map( lambda ea, prec, p: -(alpha * weight_decay) @@ -337,7 +346,7 @@ def update_fn( barf=barf, gamma=gamma, lb=new_lb, - count=numerics.safe_increment(count), + count=count_inc, ) return p_update, new_state diff --git a/optax/contrib/_prodigy.py b/optax/contrib/_prodigy.py index 3f616fdc2..fd43d9e0c 100644 --- a/optax/contrib/_prodigy.py +++ b/optax/contrib/_prodigy.py @@ -23,9 +23,9 @@ import chex import jax import jax.numpy as jnp -from optax import tree_utils from optax._src import base from optax._src import numerics +import optax.tree_utils as otu class ProdigyState(NamedTuple): @@ -87,12 +87,16 @@ def prodigy( beta3 = beta2**0.5 def init_fn(params: base.Params) -> ProdigyState: - exp_avg = jax.tree.map(lambda p: jnp.zeros(p.shape, jnp.float32), params) - exp_avg_sq = jax.tree.map(lambda p: jnp.zeros(p.shape, jnp.float32), params) - grad_sum = jax.tree.map(lambda p: jnp.zeros(p.shape, jnp.float32), params) + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of 
parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') + exp_avg = otu.tree_zeros_like(params) + exp_avg_sq = otu.tree_zeros_like(params) + grad_sum = otu.tree_zeros_like(params) params0 = params - estim_lr = jnp.asarray(estim_lr0, jnp.float32) - numerator_weighted = jnp.zeros((), jnp.float32) + estim_lr = jnp.asarray(estim_lr0, dtype=params_dtype) + numerator_weighted = jnp.zeros((), dtype=params_dtype) count = jnp.zeros((), jnp.int32) return ProdigyState( exp_avg, @@ -112,16 +116,17 @@ def update_fn( if params is None: raise ValueError(base.NO_PARAMS_MSG) count = state.count + count_inc = numerics.safe_increment(count) sched = learning_rate(count) if callable(learning_rate) else learning_rate grad_sum = state.grad_sum params0 = state.params0 estim_lr = state.estim_lr numerator_weighted = state.numerator_weighted - bc = ((1 - beta2 ** (count + 1)) ** 0.5) / (1 - beta1 ** (count + 1)) - dlr = estim_lr * sched * bc + bc = ((1 - beta2 ** count_inc) ** 0.5) / (1 - beta1 ** count_inc) + dlr = jnp.asarray(estim_lr * sched * bc, dtype=estim_lr.dtype) dg = jax.tree.map(lambda g: estim_lr * g, updates) param_diff = jax.tree.map(lambda p0, p: p0 - p, params0, params) - numerator_acum = tree_utils.tree_vdot(updates, param_diff) + numerator_acum = otu.tree_vdot(updates, param_diff) exp_avg = jax.tree.map( lambda ea, dgk: beta1 * ea + (1 - beta1) * dgk, state.exp_avg, dg ) @@ -140,7 +145,7 @@ def update_fn( ) numerator_weighted = beta3 * numerator_weighted numerator_weighted += (estim_lr / estim_lr0) * dlr * numerator_acum - denominator = tree_utils.tree_sum(jax.tree.map(jnp.abs, grad_sum)) + denominator = otu.tree_sum(jax.tree.map(jnp.abs, grad_sum)) lr_estimate = estim_lr_coef * numerator_weighted / denominator estim_lr = jnp.maximum(state.estim_lr, lr_estimate) p_update = jax.tree.map( @@ -157,7 +162,7 @@ def update_fn( params0, estim_lr, numerator_weighted, - numerics.safe_increment(count), + count_inc, ) return p_update, new_state diff --git a/optax/contrib/_reduce_on_plateau.py b/optax/contrib/_reduce_on_plateau.py index a1d19e9a0..5f710ed8d 100644 --- a/optax/contrib/_reduce_on_plateau.py +++ b/optax/contrib/_reduce_on_plateau.py @@ -26,17 +26,18 @@ import jax.numpy as jnp from optax._src import base from optax._src import numerics +import optax.tree_utils as otu class ReduceLROnPlateauState(NamedTuple): """State for the ReduceLROnPlateau callback.""" - scale: chex.Array # shape=(), dtype=jnp.float32 - best_value: chex.Array # shape=(), dtype=jnp.float32 + scale: chex.Array + best_value: chex.Array plateau_count: chex.Array # shape=(), dtype=jnp.int32 cooldown_count: chex.Array # shape=(), dtype=jnp.int32 count: chex.Array # shape=(), dtype=jnp.int32 - avg_value: chex.Array # shape=(), dtype=jnp.float32 + avg_value: chex.Array def reduce_on_plateau( @@ -96,14 +97,17 @@ def reduce_on_plateau( ) def init_fn(params) -> ReduceLROnPlateauState: - del params + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. 
+ params_dtype = otu.tree_dtype(params, "lowest") return ReduceLROnPlateauState( - best_value=jnp.asarray(float("inf"), dtype=jnp.float32), + best_value=jnp.asarray(float("inf")), plateau_count=jnp.asarray(0, jnp.int32), - scale=jnp.asarray(1.0, dtype=jnp.float32), + scale=jnp.asarray(1.0, dtype=params_dtype), cooldown_count=jnp.asarray(0, jnp.int32), count=jnp.asarray(0, jnp.int32), - avg_value=jnp.asarray(0.0, jnp.float32), + avg_value=jnp.asarray(0.0), ) def _update_scale(state): @@ -116,7 +120,7 @@ def _update_scale(state): has_improved, avg_value, state.best_value ) curr_plateau_count = jnp.where( - has_improved, 0, numerics.safe_int32_increment(state.plateau_count) + has_improved, 0, numerics.safe_increment(state.plateau_count) ) # We're in cooldown, so reduce the counter and ignore any bad epochs @@ -154,7 +158,7 @@ def not_in_cooldown(): scale=new_scale, cooldown_count=new_cooldown_count, count=jnp.asarray(0, dtype=jnp.int32), - avg_value=jnp.asarray(0.0, dtype=jnp.float32), + avg_value=jnp.asarray(0.0), ) return new_state @@ -169,7 +173,7 @@ def update_fn( del params, extra_args count = state.count - new_count = numerics.safe_int32_increment(count) + new_count = numerics.safe_increment(count) new_avg_value = ( count * state.avg_value + jnp.astype(value, state.avg_value.dtype) ) / new_count diff --git a/optax/contrib/_reduce_on_plateau_test.py b/optax/contrib/_reduce_on_plateau_test.py index c96fdb85e..ad95fbe3e 100644 --- a/optax/contrib/_reduce_on_plateau_test.py +++ b/optax/contrib/_reduce_on_plateau_test.py @@ -57,6 +57,7 @@ def test_learning_rate_reduced_after_cooldown_period_is_over( state = self.transform.init(self.updates['params']) # Wait until patience runs out + updates = self.updates for _ in range(self.patience + 1): updates, state = self.transform.update( updates=self.updates, state=state, value=jnp.asarray(1.0, dtype=float) @@ -91,12 +92,12 @@ def test_learning_rate_is_not_reduced(self, enable_x64): # State with positive plateau_count state = _reduce_on_plateau.ReduceLROnPlateauState( - best_value=jnp.array(1.0, dtype=jnp.float32), + best_value=jnp.array(1.0), plateau_count=jnp.array(3, dtype=jnp.int32), - scale=jnp.array(0.1, dtype=jnp.float32), + scale=jnp.array(0.1), cooldown_count=jnp.array(0, dtype=jnp.int32), count=jnp.array(0, dtype=jnp.int32), - avg_value=jnp.array(0.0, dtype=jnp.float32), + avg_value=jnp.array(0.0), ) # Update with better value @@ -119,12 +120,12 @@ def test_learning_rate_not_reduced_during_cooldown(self, enable_x64): # State with positive cooldown_count state = _reduce_on_plateau.ReduceLROnPlateauState( - best_value=jnp.array(1.0, dtype=jnp.float32), + best_value=jnp.array(1.0), plateau_count=jnp.array(0, dtype=jnp.int32), - scale=jnp.array(0.1, dtype=jnp.float32), + scale=jnp.array(0.1), cooldown_count=jnp.array(3, dtype=jnp.int32), count=jnp.array(0, dtype=jnp.int32), - avg_value=jnp.array(0.0, dtype=jnp.float32), + avg_value=jnp.array(0.0), ) # Update with worse value @@ -151,15 +152,16 @@ def test_learning_rate_not_reduced_after_end_scale_is_reached( # State with scale == min_scale state = _reduce_on_plateau.ReduceLROnPlateauState( - best_value=jnp.array(1.0, dtype=jnp.float32), + best_value=jnp.array(1.0), plateau_count=jnp.array(0, dtype=jnp.int32), - scale=jnp.array(0.01, dtype=jnp.float32), + scale=jnp.array(0.01), cooldown_count=jnp.array(0, dtype=jnp.int32), count=jnp.array(0, dtype=jnp.int32), - avg_value=jnp.array(0.0, dtype=jnp.float32), + avg_value=jnp.array(0.0), ) # Wait until patience runs out + updates = self.updates for _ 
in range(self.patience + 1): updates, state = self.transform.update( updates=self.updates, state=state, value=0.1, diff --git a/optax/contrib/_sam_test.py b/optax/contrib/_sam_test.py index b5cb2446c..2a9798223 100644 --- a/optax/contrib/_sam_test.py +++ b/optax/contrib/_sam_test.py @@ -14,62 +14,25 @@ # ============================================================================== """Tests for `sam.py`.""" -from typing import NamedTuple - from absl.testing import absltest from absl.testing import parameterized import chex import jax import jax.numpy as jnp from optax._src import alias -from optax._src import base from optax._src import combine from optax._src import numerics from optax._src import update from optax.contrib import _sam from optax.tree_utils import _state_utils - -# TODO(harshm): make LARS and Fromage work with SAM. -_OPTIMIZERS_UNDER_TEST = ( - dict(opt_name='sgd', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamax', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='adamaxw', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='amsgrad', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='lamb', opt_kwargs=dict(learning_rate=1.0)), - dict( - opt_name='lion', - opt_kwargs=dict(learning_rate=1.0, b1=0.99), - ), - dict(opt_name='noisy_sgd', opt_kwargs=dict(learning_rate=1.0, eta=1e-4)), - dict(opt_name='novograd', opt_kwargs=dict(learning_rate=1.0)), - dict( - opt_name='optimistic_gradient_descent', - opt_kwargs=dict(learning_rate=1.0, alpha=0.7, beta=0.1), - ), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='rmsprop', opt_kwargs=dict(learning_rate=1.0, momentum=0.9)), - dict(opt_name='adabelief', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='radam', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='sm3', opt_kwargs=dict(learning_rate=1.0)), - dict(opt_name='yogi', opt_kwargs=dict(learning_rate=1.0, b1=0.99)), -) - - -def _setup_mixture(dtype): - initial_params = jnp.array([-0.4, -0.4], dtype=dtype) - final_params = jnp.array([2.0, 0.0], dtype=dtype) - - @jax.grad - def get_updates(params): - x, y = params - return -jnp.exp(-((x - 2) ** 2) - y**2) - 1.0 * jnp.exp( - -((x) ** 2 + (y) ** 2 * 100) - ) - - return initial_params, final_params, get_updates +_BASE_OPTIMIZERS_UNDER_TEST = [ + dict(base_opt_name='sgd', base_opt_kwargs=dict(learning_rate=1e-3)), +] +_ADVERSARIAL_OPTIMIZERS_UNDER_TEST = [ + dict(adv_opt_name='sgd', adv_opt_kwargs=dict(learning_rate=1e-5)), + dict(adv_opt_name='adam', adv_opt_kwargs=dict(learning_rate=1e-4)), +] def _setup_parabola(dtype): @@ -84,106 +47,45 @@ def get_updates(params): return initial_params, final_params, get_updates -def _setup_rosenbrock(dtype): - """Rosenbrock function as an optimization target.""" - a = 1.0 - b = 100.0 - - initial_params = jnp.array([0.0, 0.0], dtype=dtype) - final_params = jnp.array([a, a**2], dtype=dtype) - - @jax.grad - def get_updates(params): - return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( - params[1] - params[0] ** 2 - ) - - return initial_params, final_params, get_updates - - -class OptimizerTestState(NamedTuple): - """Inner optimizer state for the SAM tests.""" - - aggregate_grads: base.Params - - -def _test_optimizer(step_size: float) -> base.GradientTransformation: - """Inner optimizer for the SAM tests.""" - - # Use SGD for simplicity but add non-trivial optimizer state so that the - # resetting behaviour of 
SAM can be tested. - def init_fn(params): - aggregate_grads = jax.tree.map(jnp.zeros_like, params) - return OptimizerTestState(aggregate_grads) - - def update_fn(updates, state, params): - # The test optimizer does not use the parameters, but we check that they - # have been passed correctly. - chex.assert_trees_all_equal_shapes(updates, params) - aggregate_grads = update.apply_updates(state.aggregate_grads, updates) - updates = jax.tree.map(lambda u: step_size * u, updates) - return updates, OptimizerTestState(aggregate_grads) - - return base.GradientTransformation(init_fn, update_fn) - - class SAMTest(chex.TestCase): @parameterized.product( - _OPTIMIZERS_UNDER_TEST, - sync_period=(2,), - target=(_setup_parabola,), - dtype=(jnp.float32,), - ) - def test_optimization(self, opt_name, opt_kwargs, sync_period, target, dtype): - opt = alias.sgd(0.003) - adv_opt = combine.chain( - _sam.normalize(), getattr(alias, opt_name)(**opt_kwargs) - ) - opt = _sam.sam(opt, adv_opt, sync_period=sync_period) - initial_params, final_params, get_updates = target(dtype) - - @jax.jit - def step(params, state): - updates = get_updates(params) - updates, state = opt.update(updates, state, params) - params = update.apply_updates(params, updates) - return params, state - - params = initial_params - state = opt.init(params) - # A no-op change, to verify that tree map works. - state = _state_utils.tree_map_params(opt, lambda v: v, state) - - for _ in range(25000 * sync_period): - params, state = step(params, state) - - chex.assert_trees_all_close(params, final_params, rtol=3e-2, atol=3e-2) - - @parameterized.product( - _OPTIMIZERS_UNDER_TEST, + _BASE_OPTIMIZERS_UNDER_TEST, + _ADVERSARIAL_OPTIMIZERS_UNDER_TEST, sync_period=(2,), target=(_setup_parabola,), - dtype=(jnp.float32,), + dtype=('float32',), + opaque_mode=(False, True), ) - def test_opaque_optimization( - self, opt_name, opt_kwargs, sync_period, target, dtype + def test_optimization( + self, + base_opt_name, + base_opt_kwargs, + adv_opt_name, + adv_opt_kwargs, + sync_period, + target, + dtype, + opaque_mode, ): - base_opt = alias.sgd(0.003) + dtype = jnp.dtype(dtype) + base_opt = getattr(alias, base_opt_name)(**base_opt_kwargs) adv_opt = combine.chain( - _sam.normalize(), getattr(alias, opt_name)(**opt_kwargs) + _sam.normalize(), getattr(alias, adv_opt_name)(**adv_opt_kwargs) ) opt = _sam.sam( - base_opt, adv_opt, sync_period=sync_period, opaque_mode=True + base_opt, adv_opt, sync_period=sync_period, opaque_mode=opaque_mode ) initial_params, final_params, get_updates = target(dtype) + if opaque_mode: + update_kwargs = dict(grad_fn=lambda p, _: get_updates(p)) + else: + update_kwargs = {} @jax.jit def step(params, state): updates = get_updates(params) - updates, state = opt.update( - updates, state, params, grad_fn=lambda p, _: get_updates(p) - ) + updates, state = opt.update(updates, state, params, **update_kwargs) params = update.apply_updates(params, updates) return params, state diff --git a/optax/contrib/_schedule_free.py b/optax/contrib/_schedule_free.py index 3edcc9270..658b17d87 100644 --- a/optax/contrib/_schedule_free.py +++ b/optax/contrib/_schedule_free.py @@ -22,9 +22,11 @@ from optax._src import alias from optax._src import base from optax._src import combine +from optax._src import numerics from optax._src import transform from optax.schedules import _schedule from optax.transforms import _adding +import optax.tree_utils as otu class ScheduleFreeState(NamedTuple): @@ -38,10 +40,18 @@ class ScheduleFreeState(NamedTuple): z: base.Params -def 
schedule_free_eval_params(state: ScheduleFreeState, params: base.Params): +def schedule_free_eval_params(state: base.OptState, params: base.Params): """Params for evaluation of :func:`optax.contrib.schedule_free`.""" + # Using ScheduleFreeState as a type hint above results in pytype errors in + # tests. + b1 = getattr(state, 'b1') + z = getattr(state, 'z') + if b1 is None or z is None: + raise ValueError( + 'schedule_free_eval_params requires a ScheduleFreeState as input.' + ) return jax.tree.map( - lambda yi, zi: (yi - (1.0 - state.b1) * zi) / state.b1, params, state.z + lambda yi, zi: (yi - (1.0 - b1) * zi) / b1, params, z ) @@ -50,7 +60,7 @@ def schedule_free( learning_rate: base.ScalarOrSchedule, b1: float = 0.9, weight_lr_power: float = 2.0, - state_dtype=jnp.float32, + state_dtype: Optional[jax.typing.DTypeLike] = None, ) -> base.GradientTransformationExtraArgs: r"""Turn base_optimizer schedule_free. @@ -108,6 +118,10 @@ def schedule_free( Defazio et al, `Schedule-Free Learning - A New Way to Train `_, 2024 + .. warning:: + The current implementation requires the parameter ``b1`` to be strictly + positive. + Args: base_optimizer: Base optimizer to compute updates from. learning_rate: learning_rate schedule w/o decay but with warmup. @@ -122,15 +136,20 @@ def schedule_free( base_optimizer = base.with_extra_args_support(base_optimizer) def init_fn(params: base.Params) -> ScheduleFreeState: - if b1 == 0: - raise ValueError( - 'The current implementation of schedule_free requires b1 > 0.') - z = jax.tree.map(lambda t: t.astype(state_dtype), params) + # Define state parameters with the lowest dtype of the parameters to avoid + # dtype promotion of parameters resulting in a dtype mismatch between + # parameters and updates. + params_dtype = otu.tree_dtype(params, 'lowest') + if state_dtype is not None: + otu.tree_assert_dtype_preserved(params, state_dtype) + z = otu.tree_cast(params, dtype=state_dtype) + else: + z = params return ScheduleFreeState( - b1=jnp.array(b1, dtype=jnp.float32), - weight_sum=jnp.zeros([], dtype=jnp.float32), + b1=jnp.asarray(b1, dtype=params_dtype), + weight_sum=jnp.zeros([], dtype=params_dtype), step_count=jnp.ones([], dtype=jnp.int32), - max_lr=jnp.zeros([], dtype=jnp.float32), + max_lr=jnp.zeros([], dtype=params_dtype), base_optimizer_state=base_optimizer.init(params), z=z, ) @@ -143,10 +162,12 @@ def update_fn( ): lr = learning_rate if callable(learning_rate): - lr = learning_rate(state.step_count) + lr = jnp.asarray( + learning_rate(state.step_count), dtype=state.max_lr.dtype + ) max_lr = jnp.maximum(state.max_lr, lr) - next_step_count = state.step_count + 1 + next_step_count = numerics.safe_increment(state.step_count) weight = max_lr**weight_lr_power next_total_weight = state.weight_sum + weight @@ -190,7 +211,7 @@ def update_fn( ) next_state = ScheduleFreeState( - b1=jnp.array(b1, dtype=jnp.float32), + b1=state.b1, weight_sum=next_total_weight, step_count=next_step_count, max_lr=max_lr, @@ -205,12 +226,11 @@ def update_fn( def schedule_free_sgd( learning_rate: float = 1.0, - *, - warmup_steps: int = 0, + warmup_steps: Optional[int] = None, b1: float = 0.9, - weight_decay: float = 0.0, + weight_decay: Optional[float] = None, weight_lr_power: float = 2.0, - state_dtype=jnp.float32, + state_dtype: Optional[jax.typing.DTypeLike] = None, ) -> base.GradientTransformationExtraArgs: """Schedule-Free wrapper for SGD. 
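For reference, a minimal usage sketch of the schedule-free wrapper defined above together with schedule_free_eval_params (illustrative only, not part of the patch; the toy quadratic objective, the 0.1 learning rate, and the 100 steps are arbitrary choices, and the same learning rate is passed to the base optimizer and to the wrapper, mirroring what schedule_free_sgd does internally):

import jax
import jax.numpy as jnp
import optax


def fun(params):
  # Toy objective, for illustration only.
  return jnp.sum(params ** 2)


lr = 0.1
# Schedule-free averaging plays the role of momentum, so the base SGD here
# uses none.
opt = optax.contrib.schedule_free(optax.sgd(lr), learning_rate=lr, b1=0.9)

params = jnp.array([1.0, 2.0])
state = opt.init(params)
for _ in range(100):
  grads = jax.grad(fun)(params)
  updates, state = opt.update(grads, state, params)
  params = optax.apply_updates(params, updates)

# Evaluation uses the interpolated iterate held in the state, not the raw
# params.
eval_params = optax.contrib.schedule_free_eval_params(state, params)

The schedule_free_sgd and schedule_free_adamw helpers in this file wrap the same pattern, optionally adding the warmup schedule when warmup_steps is set.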
@@ -258,14 +278,14 @@ def schedule_free_sgd( Objective function: 8.06E-01 Objective function: 2.41E-01 """ - if warmup_steps > 0: + if warmup_steps is not None: learning_rate = _schedule.warmup_constant_schedule( init_value=0, peak_value=learning_rate, warmup_steps=warmup_steps, ) optimizer = alias.sgd(learning_rate) - if weight_decay > 0: + if weight_decay is not None: optimizer = combine.chain( _adding.add_decayed_weights(weight_decay), optimizer) return schedule_free( @@ -279,14 +299,13 @@ def schedule_free_sgd( def schedule_free_adamw( learning_rate: float = 0.0025, - *, - warmup_steps: int = 0, + warmup_steps: Optional[int] = None, b1: float = 0.9, b2: float = 0.999, eps: float = 1e-8, weight_decay: float = 0.0, weight_lr_power: float = 2.0, - state_dtype=jnp.float32, + state_dtype: Optional[jax.typing.DTypeLike] = None, ) -> base.GradientTransformationExtraArgs: """Schedule-Free wrapper for AdamW. @@ -333,7 +352,7 @@ def schedule_free_adamw( Objective function: 8.94E-01 Objective function: 4.13E-01 """ - if warmup_steps > 0: + if warmup_steps is not None: learning_rate = _schedule.warmup_constant_schedule( init_value=0, peak_value=learning_rate, diff --git a/optax/contrib/_schedule_free_test.py b/optax/contrib/_schedule_free_test.py index 7c5197d10..dbc096f02 100644 --- a/optax/contrib/_schedule_free_test.py +++ b/optax/contrib/_schedule_free_test.py @@ -19,23 +19,10 @@ import chex import jax import jax.numpy as jnp -import numpy as np from optax._src import alias from optax._src import numerics -from optax._src import schedule from optax._src import update from optax.contrib import _schedule_free -from optax.tree_utils import _state_utils - - -_WARM_LR = schedule.warmup_constant_schedule(0.0, 1e-2, 5_000) - -# TODO(harshm): try other optimizers with schedule_free. -_OPTIMIZERS_UNDER_TEST = ( - dict(opt_name='sgd', opt_kwargs=dict(momentum=0.0)), - dict(opt_name='adam', opt_kwargs=dict(b1=0.0)), - dict(opt_name='adamw', opt_kwargs=dict(b1=0.0)), -) def _setup_parabola(dtype): @@ -50,79 +37,24 @@ def get_updates(params): return initial_params, final_params, get_updates -def _setup_rosenbrock(dtype): - """Rosenbrock function as an optimization target.""" - a = 1.0 - b = 100.0 - - initial_params = jnp.array([0.0, 0.0], dtype=dtype) - final_params = jnp.array([a, a**2], dtype=dtype) - - @jax.grad - def get_updates(params): - return numerics.abs_sq(a - params[0]) + b * numerics.abs_sq( - params[1] - params[0] ** 2 - ) - - return initial_params, final_params, get_updates - - class ScheduleFreeTest(chex.TestCase): - def setUp(self): - super().setUp() - self.grads = {'x': np.array(2.0), 'y': np.array(-2.0)} - self.initial_params = {'x': np.array(3.0), 'y': np.array(-3.0)} - - @parameterized.product( - _OPTIMIZERS_UNDER_TEST, - target=(_setup_parabola, _setup_rosenbrock), - dtype=(jnp.float32,), - ) - def test_optimization(self, opt_name, opt_kwargs, target, dtype): - - opt = getattr(alias, opt_name)(learning_rate=_WARM_LR, **opt_kwargs) - opt = _schedule_free.schedule_free(opt, learning_rate=_WARM_LR) - initial_params, final_params, get_updates = target(dtype) - - @jax.jit - def step(params, state): - updates = get_updates(params) - updates, state = opt.update(updates, state, params) - params = update.apply_updates(params, updates) - return params, state - - params = initial_params - state = opt.init(params) - # A no-op change, to verify that tree map works. 
- state = _state_utils.tree_map_params(opt, lambda v: v, state) - - for _ in range(25000): - params, state = step(params, state) - - chex.assert_trees_all_close( - _schedule_free.schedule_free_eval_params(state, params), - final_params, - rtol=3e-2, - atol=3e-2, - ) - - @parameterized.parameters(*_OPTIMIZERS_UNDER_TEST) - def test_learning_rate_zero(self, opt_name, opt_kwargs): - opt = getattr(alias, opt_name)(learning_rate=0.0, **opt_kwargs) - opt = _schedule_free.schedule_free(opt, learning_rate=0.0) - initial_params, _, get_updates = _setup_parabola(jnp.float32) + def test_learning_rate_zero(self): + base_opt = alias.sgd(learning_rate=0.0, momentum=0.0) + opt = _schedule_free.schedule_free(base_opt, learning_rate=0.0) + initial_params = jnp.array([1., 2.]) + fun = lambda x: jnp.sum(x**2) @jax.jit def step(params, state): - updates = get_updates(params) + updates = jax.grad(fun)(params) updates, state = opt.update(updates, state, params) params = update.apply_updates(params, updates) return params, state params = initial_params state = opt.init(params) - for _ in range(25000): + for _ in range(5): params, state = step(params, state) chex.assert_trees_all_close( @@ -131,14 +63,16 @@ def step(params, state): ) def test_schedule_free_adamw(self): + + initial_params = jnp.array([1., 2.]) + fun = lambda x: jnp.sum(x**2) + def step(params, state, opt): - updates = get_updates(params) + updates = jax.grad(fun)(params) updates, state = opt.update(updates, state, params) params = update.apply_updates(params, updates) return params, state - initial_params, _, get_updates = _setup_parabola(jnp.float32) - def run(opt): params = initial_params state = opt.init(params) @@ -164,22 +98,57 @@ def run(opt): params_wrapper = run(opt_wrapper) chex.assert_trees_all_close(params_shortcut, params_wrapper) - @parameterized.parameters(*_OPTIMIZERS_UNDER_TEST) - def test_scalar_preservance(self, opt_name, opt_kwargs): + def test_scalar_preservance(self): # Test whether the scalar arrays of shape () are preserved through # _schedule_free.schedule_free_eval_params. - base_opt = getattr(alias, opt_name)(learning_rate=0.0, **opt_kwargs) - opt = _schedule_free.schedule_free(base_opt, learning_rate=0.0) + base_opt = alias.sgd(learning_rate=1.0, momentum=0.0) + opt = _schedule_free.schedule_free(base_opt, learning_rate=1.0) params = jnp.ones((), dtype=jnp.float32) state = opt.init(params) - # NOTE(vroulet): disabling wrong-arg-types because the type checker thinks - # that the state is a generic NamedTuple rather than a ScheduleFreeState. 
- # pytype: disable=wrong-arg-types eval_params = _schedule_free.schedule_free_eval_params(state, params) - # pytype: enable=wrong-arg-types chex.assert_equal_shape([params, eval_params]) chex.assert_trees_all_equal_dtypes(params, eval_params) + @parameterized.product( + params_dtype=('bfloat16', 'float32', 'complex64', None), + state_dtype=('bfloat16', 'float32', 'complex64', None), + ) + def test_explicit_dtype(self, params_dtype, state_dtype): + base_opt = alias.sgd(learning_rate=1.0, momentum=0.0) + opt = _schedule_free.schedule_free( + base_opt, learning_rate=1.0, state_dtype=state_dtype + ) + + params_dtype = jax.dtypes.canonicalize_dtype(params_dtype) + params = jnp.array([0.0, 0.0], dtype=params_dtype) + state_has_lower_dtype = ( + jnp.promote_types(params_dtype, state_dtype) + == params_dtype + ) + if state_dtype is None or state_has_lower_dtype: + state = opt.init(params) + + with self.subTest('Test that attribute dtype is correct'): + if state_dtype is None: + expected_dtype = params_dtype + else: + expected_dtype = jax.dtypes.canonicalize_dtype(state_dtype) + self.assertEqual(expected_dtype, getattr(state, 'z').dtype) + + with self.subTest( + 'Verifies that the updates keep the same type as params' + ): + updates, _ = opt.update(jnp.ones_like(params), state, params) + self.assertEqual(getattr(updates, 'dtype'), params.dtype) + else: + with self.subTest( + 'Test that we forbid setting dtype s.t. updates dtype get promoted to' + ' the state dtype' + ): + with self.assertRaises(ValueError): + opt.init(params) + + if __name__ == '__main__': absltest.main() diff --git a/optax/transforms/_accumulation.py b/optax/transforms/_accumulation.py index 25ae2123c..b73b35dfe 100644 --- a/optax/transforms/_accumulation.py +++ b/optax/transforms/_accumulation.py @@ -52,11 +52,17 @@ def trace( Returns: A `GradientTransformation` object. + + Raises: + ValueError: If the selected ``accumulator_dtype`` induces a dtype promotion + of the dtypes of the parameters. """ accumulator_dtype = utils.canonicalize_dtype(accumulator_dtype) def init_fn(params): + if accumulator_dtype is not None: + otu.tree_assert_dtype_preserved(params, accumulator_dtype) return TraceState( trace=otu.tree_zeros_like(params, dtype=accumulator_dtype)) diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index 44e2e195d..06aca15ef 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -15,7 +15,9 @@ """The tree_utils sub-package.""" # pylint: disable=g-importing-member +from optax.tree_utils._casting import tree_assert_dtype_preserved from optax.tree_utils._casting import tree_cast +from optax.tree_utils._casting import tree_dtype from optax.tree_utils._random import tree_random_like from optax.tree_utils._random import tree_split_key_like from optax.tree_utils._state_utils import NamedTupleKey diff --git a/optax/tree_utils/_casting.py b/optax/tree_utils/_casting.py index 83a33b032..07754445b 100644 --- a/optax/tree_utils/_casting.py +++ b/optax/tree_utils/_casting.py @@ -14,18 +14,242 @@ # ============================================================================== """Utilities to cast pytrees to specific dtypes.""" +import functools from typing import Optional import chex import jax +import jax.numpy as jnp def tree_cast( - tree: chex.ArrayTree, - dtype: Optional[chex.ArrayDType] + tree: chex.ArrayTree, dtype: Optional[chex.ArrayDType] ) -> chex.ArrayTree: - """Cast tree to given dtype, skip if None.""" + """Cast tree to given dtype, skip if None. 
+ + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, + ... 'c': jnp.array(2.0, dtype=jnp.float32)} + >>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16) + {'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)} + + Args: + tree: the tree to cast. + dtype: the dtype to cast to, or None to skip. + + Returns: + the tree, with leaves casted to dtype. + """ if dtype is not None: return jax.tree.map(lambda t: t.astype(dtype), tree) else: return tree + + +def tree_dtype( + tree: chex.ArrayTree, mixed_dtype_handler: Optional[str] = None +) -> chex.ArrayDType: + """Fetch dtype of tree. + + If the tree is empty, returns the default dtype of JAX arrays. + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, + ... 'c': jnp.array(2.0, dtype=jnp.float32)} + >>> optax.tree_utils.tree_dtype(tree) + dtype('float32') + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)}, + ... 'c': jnp.array(2.0, dtype=jnp.float32)} + >>> optax.tree_utils.tree_dtype(tree, 'lowest') + dtype('float16') + >>> optax.tree_utils.tree_dtype(tree, 'highest') + dtype('float32') + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)}, + ... 'c': jnp.array(2.0, dtype=jnp.uint32)} + >>> # optax.tree_utils.tree_dtype(tree, 'highest') + >>> # -> will throw an error because int32 and uint32 + >>> # cannot be promoted to one another. + >>> optax.tree_utils.tree_dtype(tree, 'promote') + dtype('int64') + + Args: + tree: the tree to fetch the dtype of. + mixed_dtype_handler: how to handle mixed dtypes in the tree. + + - If ``mixed_dtype_handler=None``, returns the common dtype of the leaves + of the tree if it exists, otherwise raises an error. + - If ``mixed_dtype_handler='promote'``, promotes the dtypes of the leaves + of the tree to a common promoted dtype using + :func:`jax.numpy.promote_types`. + - If ``mixed_dtype_handler='highest'`` or + ``mixed_dtype_handler='lowest'``, returns the highest/lowest dtype of + the leaves of the tree. We consider a partial ordering of dtypes as + ``dtype1 <= dtype2`` if ``dtype1`` is promoted to ``dtype2``, that is, + if ``jax.numpy.promote_types(dtype1, dtype2) == dtype2``. Since some + dtypes cannot be promoted to one another, this is not a total ordering, + and the 'highest' or 'lowest' options may not be applicable. These + options will throw an error if the dtypes of the leaves of the tree + cannot be promoted to one another. + + Returns: + the dtype of the tree. + + Raises: + ValueError: If ``mixed_dtype_handler`` is set to ``None`` and multiple + dtypes are found in the tree. + ValueError: If ``mixed_dtype_handler`` is set to ``'highest'`` or + ``'lowest'`` and some leaves' dtypes in the tree cannot be promoted to one + another. + + .. seealso:: :func:`jax.numpy.promote_types`, + `Type promotion semantics in JAX + `_ + + .. versionadded:: 0.2.4 + """ + leaves = jax.tree.leaves(tree) + if not leaves: + # If the tree is empty, we return the default dtype as given by JAX on + # empty lists. 
+ return jnp.dtype(jnp.asarray(leaves)) + if mixed_dtype_handler is None: + dtype = jnp.asarray(leaves[0]).dtype + _tree_assert_all_dtypes_equal(tree, dtype) + return dtype + elif mixed_dtype_handler == 'promote': + promoted_dtype = functools.reduce( + jnp.promote_types, [jnp.asarray(x).dtype for x in leaves] + ) + return promoted_dtype + elif mixed_dtype_handler == 'highest': + highest_dtype = functools.reduce( + _higher_dtype, [jnp.asarray(x).dtype for x in leaves] + ) + return highest_dtype + elif mixed_dtype_handler == 'lowest': + lowest_dtype = functools.reduce( + _lower_dtype, [jnp.asarray(x).dtype for x in leaves] + ) + return lowest_dtype + else: + raise ValueError( + f'Invalid value for {mixed_dtype_handler=}, possible values are: None,' + ' "promote", "highest", "lowest".' + ) + + +def tree_assert_dtype_preserved( + tree: chex.ArrayTree, + dtype: chex.ArrayDType, +) -> None: + """Checks whether some elements of tree may be promoted to dtype. + + Some transformations like :func:`optax.scale_by_adam`, :func:`optax.trace` + allow the user to specify a dtype for some of the state's parameters (e.g. the + momentum term). This function checks that the specified dtype of the state's + parameters does not induce a dtype promotion of any of the parameters. That + way we can ensure that the dtype of the updates are consistent with the dtype + of the parameters. + + Args: + tree: the tree to check. + dtype: the dtype to check against. + + Raises: + ValueError: If any element of the tree is promoted to dtype. + + .. versionadded:: 0.2.4 + """ + + def _assert_dtype_preserved(path, x): + x_dtype = jnp.asarray(x).dtype + if jnp.promote_types(x_dtype, dtype) != x_dtype: + err_msg = ( + f'{dtype=} induces dtype promotion for {path} with dtype {x_dtype}.' + ) + return err_msg + + err_msgs = jax.tree.leaves( + jax.tree_util.tree_map_with_path(_assert_dtype_preserved, tree) + ) + err_msgs = [err_msg for err_msg in err_msgs if err_msg is not None] + if err_msgs: + raise ValueError('\n'.join(err_msgs)) + + +def _tree_assert_all_dtypes_equal( + tree: chex.ArrayTree, dtype: chex.ArrayDType +) -> None: + """Checks that all leaves of the tree have the given dtype. + + Args: + tree: the tree to check. + dtype: the dtype to check against. + + Raises: + ValueError: If any element of the tree does not match the given dtype. + """ + + def _assert_dtypes_equal(path, x): + x_dtype = jnp.asarray(x).dtype + if x_dtype != dtype: + err_msg = f'Expected {dtype=} for {path} but got {x_dtype}.' + return err_msg + + err_msgs = jax.tree.leaves( + jax.tree_util.tree_map_with_path(_assert_dtypes_equal, tree) + ) + err_msgs = [err_msg for err_msg in err_msgs if err_msg is not None] + if err_msgs: + raise ValueError('\n'.join(err_msgs)) + + +def _lower_dtype( + dtype1: chex.ArrayDType, dtype2: chex.ArrayDType +) -> chex.ArrayDType: + """Returns lower dtype among two dtypes, if any can be promoted to the other. + + Args: + dtype1: The first dtype to compare. + dtype2: The second dtype to compare. + + Returns: + The lowest of the two dtypes, if any can be promoted to the other. + + Raises: + ValueError: If none of the dtypes can be promoted to the other. + """ + if jnp.promote_types(dtype1, dtype2) == dtype1: + return dtype2 + elif jnp.promote_types(dtype1, dtype2) == dtype2: + return dtype1 + else: + raise ValueError( + f'Cannot compare dtype of {dtype1=} and {dtype2=}.' + f' Neither {dtype1} nor {dtype2} can be promoted to the other.' 
+ ) + + +def _higher_dtype( + dtype1: chex.ArrayDType, dtype2: chex.ArrayDType +) -> chex.ArrayDType: + """Returns higher dtype among two dtypes, if any can be promoted to the other. + + Args: + dtype1: The first dtype to compare. + dtype2: The second dtype to compare. + + Returns: + The highest of the two dtypes, if any can be promoted to the other. + + Raises: + ValueError: If none of the dtypes can be promoted to the other. + """ + if _lower_dtype(dtype1, dtype2) == dtype1: + return dtype2 + else: + return dtype1 diff --git a/optax/tree_utils/_casting_test.py b/optax/tree_utils/_casting_test.py index 08c846206..1b9a546b3 100644 --- a/optax/tree_utils/_casting_test.py +++ b/optax/tree_utils/_casting_test.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for optax.tree_utils._casting.""" +"""Tests for tree utilities on data types.""" from absl.testing import absltest from absl.testing import parameterized - import jax import jax.numpy as jnp import numpy as np - from optax import tree_utils as otu @@ -41,9 +39,109 @@ def _build_tree(val1, val2): tree = _build_tree(b, c) tree = otu.tree_cast(tree, dtype=dtype) - jax.tree.map( - np.testing.assert_array_equal, tree, _build_tree(new_b, new_c) - ) + jax.tree.map(np.testing.assert_array_equal, tree, _build_tree(new_b, new_c)) + + def test_tree_dtype(self): + """Test fecthing data type of a tree.""" + + with self.subTest('Check that it returns the right dtype'): + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + dtype = otu.tree_dtype(tree) + self.assertEqual(dtype, jnp.float32) + + with self.subTest('Check that it raises an error if dtypes differ'): + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + self.assertRaises(ValueError, otu.tree_dtype, tree) + + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + + with self.subTest('Check that it works with lowest common dtype'): + dtype = otu.tree_dtype(tree, 'lowest') + self.assertEqual(dtype, jnp.bfloat16) + + with self.subTest('Check that it works with highest common dtype'): + dtype = otu.tree_dtype(tree, 'highest') + self.assertEqual(dtype, jnp.float32) + + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float16), + } + + with self.subTest('Check that it works when promoting mixed dtype'): + dtype = otu.tree_dtype(tree, 'promote') + self.assertEqual(dtype, jnp.float32) + + with self.subTest( + 'Check that it raises an error if no dtypes cannot be promoted to one' + ' another' + ): + self.assertRaises(ValueError, otu.tree_dtype, tree, 'lowest') + self.assertRaises(ValueError, otu.tree_dtype, tree, 'highest') + + def test_tree_assert_dtype_preserved(self): + """Test utility asserting no promotion of data types in a tree for given data type.""" + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + + with self.subTest( + 'Check that it raises an error if given dtype induces promotion of at' + ' least one element.' + ): + with self.assertRaises(ValueError): + otu.tree_assert_dtype_preserved(tree, jnp.float32) + + with self.subTest( + 'Check that it runs fine if no element gets promoted by given dtype.' 
+ ): + otu.tree_assert_dtype_preserved(tree, jnp.bfloat16) + + with self.subTest( + 'Check that it naturally succeeds when considering lowest common dtype.' + ): + otu.tree_assert_dtype_preserved(tree, otu.tree_dtype(tree, 'lowest')) + + with self.subTest( + 'Check that it naturally fails when considering highest common dtype.' + ): + with self.assertRaises(ValueError): + otu.tree_assert_dtype_preserved(tree, otu.tree_dtype(tree, 'highest')) + + with self.subTest('Check that it works with empty trees.'): + for tree in [(), {}, None]: + otu.tree_assert_dtype_preserved(tree, jnp.float32) + + @parameterized.named_parameters( + dict(testcase_name='empty_dict', tree={}), + dict(testcase_name='empty_list', tree=[]), + dict(testcase_name='empty_tuple', tree=()), + dict(testcase_name='empty_none', tree=None), + ) + def test_tree_dtype_utilities_with_empty_trees(self, tree): + """Test tree data type utilities on empty trees.""" + default_dtype = jnp.asarray(1.0).dtype + + with self.subTest('Check tree_dtype works with empty trees.'): + dtype = otu.tree_dtype(tree) + self.assertEqual(dtype, default_dtype) + + with self.subTest( + 'Check tree_assert_dtype_preserved succeeds with any dtype for' + ' empty trees.' + ): + # There is no array in the tree to check, so it should succeed. + otu.tree_assert_dtype_preserved(tree, jnp.complex64) if __name__ == '__main__':