From ba042c542da52640ba6e29a9c47798ed6a3990a5 Mon Sep 17 00:00:00 2001 From: Fabian Pedregosa Date: Wed, 7 Feb 2024 20:30:54 +0000 Subject: [PATCH] FIX test failing after latest jax release Workaround for https://github.com/google/jax/issues/19713 --- .gitignore | 1 + optax/_src/alias_test.py | 4 ++-- optax/_src/wrappers_test.py | 8 ++++---- optax/contrib/common_test.py | 4 ++-- test.sh | 4 ++++ 5 files changed, 13 insertions(+), 8 deletions(-) diff --git a/.gitignore b/.gitignore index 030368d1..11a35419 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ build/ dist/ venv/ +_testing/ # Building the documentation docs/_autosummary diff --git a/optax/_src/alias_test.py b/optax/_src/alias_test.py index 8e11b197..f80d9b64 100644 --- a/optax/_src/alias_test.py +++ b/optax/_src/alias_test.py @@ -158,8 +158,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( else: opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs) - params = [-jnp.ones((2, 3)), jnp.ones((2, 5, 2))] - grads = [jnp.ones((2, 3)), -jnp.ones((2, 5, 2))] + params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))] + grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))] state = self.variant(opt.init)(params) updates, new_state = self.variant(opt.update)(grads, state, params) diff --git a/optax/_src/wrappers_test.py b/optax/_src/wrappers_test.py index c8fee120..a567802a 100644 --- a/optax/_src/wrappers_test.py +++ b/optax/_src/wrappers_test.py @@ -364,21 +364,21 @@ def test_multi_steps_skip_not_finite(self): updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.ones([])) + np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_inf_updates'): updates, opt_state = opt_update( dict(a=jnp.array(float('inf'))), opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.ones([])) + np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_nan_updates'): updates, opt_state = opt_update( dict(a=jnp.full([], float('nan'))), opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) # No increase in mini_step params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.ones([])) + np.testing.assert_array_equal(params['a'], jnp.negative(jnp.ones([]))) with self.subTest('test_final_good_updates'): updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) @@ -387,7 +387,7 @@ def test_multi_steps_skip_not_finite(self): updates, opt_state = opt_update(dict(a=jnp.ones([])), opt_state, params) self.assertEqual(int(opt_state.mini_step), 0) params = update.apply_updates(params, updates) - np.testing.assert_array_equal(params['a'], -jnp.full([], 2.)) + np.testing.assert_array_equal(params['a'], jnp.negative(jnp.full([], 2.))) class MaskedTest(chex.TestCase): diff --git a/optax/contrib/common_test.py b/optax/contrib/common_test.py index 8ca588b3..01ac61c3 100644 --- a/optax/contrib/common_test.py +++ b/optax/contrib/common_test.py @@ -104,8 +104,8 @@ def test_optimizers_can_be_wrapped_in_inject_hyperparams( opt = opt_factory(**opt_kwargs) opt_inject = _inject.inject_hyperparams(opt_factory)(**opt_kwargs) - params = [-jnp.ones((2, 3)), jnp.ones((2, 5, 2))] - grads = [jnp.ones((2, 3)), -jnp.ones((2, 5, 2))] + params = [jnp.negative(jnp.ones((2, 3))), jnp.ones((2, 5, 2))] + grads = [jnp.ones((2, 3)), jnp.negative(jnp.ones((2, 5, 2)))] state = self.variant(opt.init)(params) updates, new_state = self.variant(opt.update)(grads, state, params) diff --git a/test.sh b/test.sh index 84168195..4272d580 100755 --- a/test.sh +++ b/test.sh @@ -18,6 +18,7 @@ set -xeuo pipefail # Install deps in a virtual env. rm -rf _testing +rm -rf .pytype mkdir -p _testing readonly VENV_DIR="$(mktemp -d -p `pwd`/_testing optax-env.XXXXXXXX)" # in the unlikely case in which there was something in that directory @@ -82,6 +83,9 @@ cd docs && make html make doctest cd .. +# cleanup +rm -rf _testing + set +u deactivate echo "All tests passed. Congrats!"