From 27c06b62d408b7a9c09d2133478cdcf15e2b8da7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 24 Oct 2024 16:13:32 -0700 Subject: [PATCH] Avoid assert_array_equal on PRNG keys. This operates via conversion to np.array, which will soon be disallowed by https://github.com/jax-ml/jax/pull/24481. --- tests/core/core_lift_test.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index 276d78e988..5ff7e3e696 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import operator - import jax import numpy as np from absl.testing import absltest @@ -170,12 +168,11 @@ def body_fn(scope, c): ) self.assertEqual(vars['state']['acc'], x) self.assertEqual(c, 2 * x) - np.testing.assert_array_equal( + self.assertEqual( vars['state']['rng_params'][0], vars['state']['rng_params'][1] ) with jax.debug_key_reuse(False): - np.testing.assert_array_compare( - operator.__ne__, + self.assertNotEqual( vars['state']['rng_loop'][0], vars['state']['rng_loop'][1], )