diff --git a/flax/core/scope.py b/flax/core/scope.py index 2ad5370336..8848596dff 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -1189,19 +1189,24 @@ def _is_valid_variables(variables: VariableDict) -> bool: def _is_valid_rng(rng: Array): """Checks whether rng is a valid JAX PRNGKey, also handling custom prngs.""" - # New-style JAX KeyArrays have a base type. - if jax_config.jax_enable_custom_prng: # type: ignore[attr-defined] - if not isinstance(rng, jax.random.KeyArray): - return False - # Old-style JAX PRNGKeys are plain uint32 arrays. - else: - if not isinstance(rng, (np.ndarray, jnp.ndarray)): - return False - if ( - rng.shape != random.default_prng_impl().key_shape - or rng.dtype != jnp.uint32 - ): - return False + # This check is valid for either new-style or old-style PRNG keys + if not isinstance(rng, (np.ndarray, jnp.ndarray)): + return False + + # Handle new-style typed PRNG keys + if hasattr(jax.dtypes, 'prng_key'): # JAX 0.4.14 or newer + if jax.dtypes.issubdtype(rng.dtype, jax.dtypes.prng_key): + return True + elif hasattr(jax.random, 'KeyArray'): # Previous JAX versions + if isinstance(rng, jax.random.KeyArray): + return True + + # Handle old-style raw PRNG keys + if ( + rng.shape != random.default_prng_impl().key_shape + or rng.dtype != jnp.uint32 + ): + return False return True