From e44155cdf1b575fe2dff7cfb29a184fec55432d5 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Thu, 12 Sep 2024 11:05:28 -0700 Subject: [PATCH] Add double precision tests for safe_increment and fix warnings on float64_test.py PiperOrigin-RevId: 673920430 --- optax/_src/float64_test.py | 33 ++++++++++++++++++++------------- optax/_src/numerics_test.py | 28 ++++++++++++++++------------ 2 files changed, 36 insertions(+), 25 deletions(-) diff --git a/optax/_src/float64_test.py b/optax/_src/float64_test.py index c3170d525..cde5f0360 100644 --- a/optax/_src/float64_test.py +++ b/optax/_src/float64_test.py @@ -16,12 +16,9 @@ from absl.testing import absltest from absl.testing import parameterized - import chex import jax -from jax import config import jax.numpy as jnp - from optax._src import alias from optax._src import base from optax._src import clipping @@ -40,10 +37,16 @@ ('scale_by_stddev', transform.scale_by_stddev, {}), ('adam', transform.scale_by_adam, {}), ('scale', transform.scale, dict(step_size=3.0)), - ('add_decayed_weights', transform.add_decayed_weights, - dict(weight_decay=0.1)), - ('scale_by_schedule', transform.scale_by_schedule, - dict(step_size_fn=lambda x: x * 0.1)), + ( + 'add_decayed_weights', + transform.add_decayed_weights, + dict(weight_decay=0.1), + ), + ( + 'scale_by_schedule', + transform.scale_by_schedule, + dict(step_size_fn=lambda x: x * 0.1), + ), ('scale_by_trust_ratio', transform.scale_by_trust_ratio, {}), ('add_noise', transform.add_noise, dict(eta=1.0, gamma=0.1, seed=42)), ('apply_every_k', transform.apply_every, {}), @@ -69,25 +72,29 @@ def _assert_dtype_equals(self, tree1, tree2): @chex.all_variants @parameterized.named_parameters(ALL_MODULES) def test_mixed_dtype_input_outputs(self, transform_constr, transform_kwargs): + jax.config.update('jax_enable_x64', True) initial_params = ( - jnp.array([1., 2.], dtype=jnp.float32), - jnp.array([3., 4.], dtype=jnp.float64)) + jnp.array([1.0, 2.0], dtype=jnp.float32), + jnp.array([3.0, 4.0], dtype=jnp.float64), + ) updates = ( - jnp.array([10., 21.], dtype=jnp.float32), - jnp.array([33., 42.], dtype=jnp.float64)) + jnp.array([10.0, 21.0], dtype=jnp.float32), + jnp.array([33.0, 42.0], dtype=jnp.float64), + ) scaler = transform_constr(**transform_kwargs) init_fn = self.variant(scaler.init) update_fn = self.variant(scaler.update) initial_state = init_fn(initial_params) updates, new_state = update_fn( - updates, initial_state, params=initial_params) + updates, initial_state, params=initial_params + ) new_params = update.apply_updates(initial_params, updates) self._assert_dtype_equals(initial_state, new_state) self._assert_dtype_equals(initial_params, new_params) + jax.config.update('jax_enable_x64', False) if __name__ == '__main__': - config.update('jax_enable_x64', True) absltest.main() diff --git a/optax/_src/numerics_test.py b/optax/_src/numerics_test.py index 20f9caba8..0bfcfbdeb 100644 --- a/optax/_src/numerics_test.py +++ b/optax/_src/numerics_test.py @@ -44,19 +44,21 @@ def _invalid_ord_axis_inputs(ord_axis_keepdims): class NumericsTest(chex.TestCase): @chex.all_variants - @parameterized.product( - str_dtype=[ - "bfloat16", - "float16", - "float32", - "int8", - "int16", - "int32", - ] - ) - def test_safe_increment(self, str_dtype): + @parameterized.parameters(*( + "bfloat16", + "float16", + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + )) + def test_safe_increment(self, dtype): """Tests that safe_increment works for all dtypes.""" - dtype = jnp.dtype(str_dtype) + if dtype in ["float64", "int64"]: + jax.config.update("jax_enable_x64", True) + dtype = jnp.dtype(dtype) inc_fn = self.variant(numerics.safe_increment) with self.subTest("Increments correctly"): @@ -75,6 +77,8 @@ def test_safe_increment(self, str_dtype): base = jnp.asarray(max_val, dtype=dtype) incremented = inc_fn(base) np.testing.assert_array_equal(incremented, base) + if dtype in ["float64", "int64"]: + jax.config.update("jax_enable_x64", False) @parameterized.product( str_dtype=[