Skip to content

Commit

Permalink
Merge pull request #3314 from jakevdp:typed-prng
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 564780788
  • Loading branch information
Flax Authors committed Sep 12, 2023
2 parents a5cf5db + 9a1496f commit 9aeb909
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -1189,19 +1189,24 @@ def _is_valid_variables(variables: VariableDict) -> bool:

def _is_valid_rng(rng: Array):
"""Checks whether rng is a valid JAX PRNGKey, also handling custom prngs."""
# New-style JAX KeyArrays have a base type.
if jax_config.jax_enable_custom_prng: # type: ignore[attr-defined]
if not isinstance(rng, jax.random.KeyArray):
return False
# Old-style JAX PRNGKeys are plain uint32 arrays.
else:
if not isinstance(rng, (np.ndarray, jnp.ndarray)):
return False
if (
rng.shape != random.default_prng_impl().key_shape
or rng.dtype != jnp.uint32
):
return False
# This check is valid for either new-style or old-style PRNG keys
if not isinstance(rng, (np.ndarray, jnp.ndarray)):
return False

# Handle new-style typed PRNG keys
if hasattr(jax.dtypes, 'prng_key'): # JAX 0.4.14 or newer
if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key):
return True
elif hasattr(jax.random, 'KeyArray'): # Previous JAX versions
if isinstance(rng, jax.random.KeyArray):
return True

# Handle old-style raw PRNG keys
if (
rng.shape != random.default_prng_impl().key_shape
or rng.dtype != jnp.uint32
):
return False
return True


Expand Down

0 comments on commit 9aeb909

Please sign in to comment.