Skip to content

Commit

Permalink
fix tests for numpy 2.0 compatibility
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 677662507
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Sep 23, 2024
1 parent d584551 commit fcd0071
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 24 deletions.
20 changes: 12 additions & 8 deletions tests/jax_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from functools import partial

from absl.testing import absltest
from absl.testing import parameterized
import chex
from flax import jax_utils
import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import parameterized

from flax import jax_utils

NDEV = 4

Expand All @@ -44,6 +44,7 @@ def test_basics(self, dtype, bs):
# Just tests that basic calling works without exploring caveats.
@partial(jax_utils.pad_shard_unpad, static_argnums=())
def add(a, b):
b = jnp.asarray(b, dtype=dtype)
return a + b

x = np.arange(bs, dtype=dtype)
Expand All @@ -58,7 +59,7 @@ def test_trees(self, dtype, bs):
def add(a, b):
return a['a'] + b[0]

x = np.arange(bs, dtype=dtype)
x = jnp.arange(bs, dtype=dtype)
y = add(dict(a=x), (10 * x,))
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
Expand All @@ -69,12 +70,13 @@ def test_min_device_batch_avoids_recompile(self, dtype):
@jax.jit
@chex.assert_max_traces(n=1)
def add(a, b):
b = jnp.asarray(b, dtype=dtype)
return a + b

chex.clear_trace_counter()

for bs in self.BATCH_SIZES:
x = np.arange(bs, dtype=dtype)
x = jnp.arange(bs, dtype=dtype)
y = add(x, 10 * x, min_device_batch=9) # pylint: disable=unexpected-keyword-arg
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(x + 10 * x))
Expand All @@ -83,9 +85,9 @@ def add(a, b):
def test_static_argnum(self, dtype, bs):
@partial(jax_utils.pad_shard_unpad, static_argnums=(1,))
def add(a, b):
return a + b
return a + jnp.asarray(b, dtype=dtype)

x = np.arange(bs, dtype=dtype)
x = jnp.arange(bs, dtype=dtype)
y = add(x, 10)
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(x + 10))
Expand All @@ -96,9 +98,11 @@ def test_static_argnames(self, dtype, bs):
# test the default/most canonical path where `params` are the first arg.
@partial(jax_utils.pad_shard_unpad, static_argnames=('b',))
def add(params, a, *, b):
params = jnp.asarray(params, dtype=dtype)
b = jnp.asarray(b, dtype=dtype)
return params * a + b

x = np.arange(bs, dtype=dtype)
x = jnp.arange(bs, dtype=dtype)
y = add(5, x, b=10)
chex.assert_type(y.dtype, x.dtype)
np.testing.assert_allclose(np.float64(y), np.float64(5 * x + 10))
Expand Down
18 changes: 2 additions & 16 deletions tests/serialization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,8 +197,7 @@ def __call__(self):
)
self.assertEqual(variables, deserialized_state)

@parameterized.parameters(
[
@parameterized.parameters([
'byte',
'b',
'ubyte',
Expand All @@ -222,11 +221,9 @@ def __call__(self):
'd',
'longdouble',
'g',
'cfloat',
'cdouble',
'clongdouble',
'm',
'bool8',
'b1',
'int64',
'i8',
Expand Down Expand Up @@ -259,26 +256,15 @@ def __call__(self):
'i1',
'uint8',
'u1',
'complex_',
'int0',
'uint0',
'single',
'csingle',
'singlecomplex',
'float_',
'intc',
'uintc',
'int_',
'longfloat',
'clongfloat',
'longcomplex',
'bool_',
'int',
'float',
'complex',
'bool',
]
)
])
def test_numpy_serialization(self, dtype):
np.random.seed(0)
if (
Expand Down

0 comments on commit fcd0071

Please sign in to comment.