diff --git a/optax/_src/float64_test.py b/optax/_src/float64_test.py index c3170d52..43d39622 100644 --- a/optax/_src/float64_test.py +++ b/optax/_src/float64_test.py @@ -16,15 +16,14 @@ from absl.testing import absltest from absl.testing import parameterized - import chex import jax -from jax import config import jax.numpy as jnp - +import numpy as np from optax._src import alias from optax._src import base from optax._src import clipping +from optax._src import numerics from optax._src import transform from optax._src import update @@ -40,10 +39,16 @@ ('scale_by_stddev', transform.scale_by_stddev, {}), ('adam', transform.scale_by_adam, {}), ('scale', transform.scale, dict(step_size=3.0)), - ('add_decayed_weights', transform.add_decayed_weights, - dict(weight_decay=0.1)), - ('scale_by_schedule', transform.scale_by_schedule, - dict(step_size_fn=lambda x: x * 0.1)), + ( + 'add_decayed_weights', + transform.add_decayed_weights, + dict(weight_decay=0.1), + ), + ( + 'scale_by_schedule', + transform.scale_by_schedule, + dict(step_size_fn=lambda x: x * 0.1), + ), ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), ('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)), ('apply_every_k', transform.apply_every, {}), @@ -69,25 +74,56 @@ def _assert_dtype_equals(self, tree1, tree2): @chex.all_variants @parameterized.named_parameters(ALL_MODULES) def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs): + if not jax.config.jax_enable_x64: + raise self.skipTest('jax_enable_x64 is not set') initial_params = ( - jnp.array([1., 2.], dtype=jnp.float32), - jnp.array([3., 4.], dtype=jnp.float64)) + jnp.array([1.0, 2.0], dtype=jnp.float32), + jnp.array([3.0, 4.0], dtype=jnp.float64), + ) updates = ( - jnp.array([10., 21.], dtype=jnp.float32), - jnp.array([33., 42.], dtype=jnp.float64)) + jnp.array([10.0, 21.0], dtype=jnp.float32), + jnp.array([33.0, 42.0], dtype=jnp.float64), + ) scaler = transform_constr(**transform_kwargs) init_fn = self.variant(scaler.init) update_fn = self.variant(scaler.update) initial_state = init_fn(initial_params) updates, new_state = update_fn( - updates, initial_state, params=initial_params) + updates, initial_state, params=initial_params + ) new_params = update.apply_updates(initial_params, updates) self._assert_dtype_equals(initial_state, new_state) self._assert_dtype_equals(initial_params, new_params) + @chex.all_variants + @parameterized.product(str_dtype=['float64', 'int64']) + def test_safe_increment(self, str_dtype): + """Tests that safe_increment works for all dtypes.""" + if not jax.config.jax_enable_x64: + raise self.skipTest('jax_enable_x64 is not set') + dtype = jnp.dtype(str_dtype) + inc_fn = self.variant(numerics.safe_increment) + + with self.subTest('Increments correctly'): + x = jnp.asarray(3, dtype=dtype) + incremented = inc_fn(x) + expected = jnp.asarray(4, dtype=dtype) + np.testing.assert_array_equal(incremented, expected) + + with self.subTest('Avoids overflow'): + if jnp.issubdtype(dtype, jnp.integer): + max_val = jnp.iinfo(dtype).max + elif jnp.issubdtype(dtype, jnp.floating): + max_val = jnp.finfo(dtype).max + else: + raise ValueError(f'Unsupported dtype: {dtype}') + x = jnp.asarray(max_val, dtype=dtype) + incremented = inc_fn(x) + np.testing.assert_array_equal(incremented, x) + if __name__ == '__main__': - config.update('jax_enable_x64', True) + jax.config.update('jax_enable_x64', True) absltest.main()