Skip to content

Commit

Permalink
Update lift to use the new PRNGKey check to suppress deprecation warn…
Browse files Browse the repository at this point in the history
…ings

PiperOrigin-RevId: 566339580
  • Loading branch information
Flax Team committed Sep 18, 2023
1 parent 5d846a5 commit 830d335
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions flax/core/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,9 @@


def tree_map_rngs(fn, tree):
"""Needed for mapping JAX random.* functions over KeyArray leaves."""
"""Needed for mapping JAX random.* functions over PRNGKey leaves."""
return jax.tree_util.tree_map(
fn, tree, is_leaf=lambda x: isinstance(x, random.KeyArray)
fn, tree, is_leaf=lambda x: jax.dtypes.issubdtype(x, jax.dtypes.prng_key)
)


Expand Down

0 comments on commit 830d335

Please sign in to comment.