diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index b199697f6..01c90dda7 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -100,6 +100,7 @@ Tree tree_map_params tree_mul tree_ones_like + tree_random_split tree_random_like tree_scalar_mul tree_set @@ -153,6 +154,10 @@ Tree ones like ~~~~~~~~~~~~~~ .. autofunction:: tree_ones_like +Tree with random keys +~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: tree_random_split + Tree with random values ~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_random_like diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index e89aef861..641e41446 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -17,6 +17,7 @@ # pylint: disable=g-importing-member from optax.tree_utils._casting import tree_cast from optax.tree_utils._random import tree_random_like +from optax.tree_utils._random import tree_random_split from optax.tree_utils._state_utils import NamedTupleKey from optax.tree_utils._state_utils import tree_get from optax.tree_utils._state_utils import tree_get_all_with_path diff --git a/optax/tree_utils/_random.py b/optax/tree_utils/_random.py index 33783b2b1..2b44e1a25 100644 --- a/optax/tree_utils/_random.py +++ b/optax/tree_utils/_random.py @@ -18,10 +18,9 @@ import chex import jax -from jax import tree_util as jtu -def _tree_rng_keys_split( +def tree_random_split( rng_key: chex.PRNGKey, target_tree: chex.ArrayTree ) -> chex.ArrayTree: """Split keys to match structure of target tree. @@ -33,9 +32,9 @@ def _tree_rng_keys_split( Returns: a tree of rng keys. """ - tree_def = jtu.tree_structure(target_tree) + tree_def = jax.tree.structure(target_tree) keys = jax.random.split(rng_key, tree_def.num_leaves) - return jtu.tree_unflatten(tree_def, keys) + return jax.tree.unflatten(tree_def, keys) def tree_random_like( @@ -67,8 +66,8 @@ def tree_random_like( .. versionadded:: 0.2.1 """ - keys_tree = _tree_rng_keys_split(rng_key, target_tree) - return jtu.tree_map( + keys_tree = tree_random_split(rng_key, target_tree) + return jax.tree.map( lambda l, k: sampler(k, l.shape, dtype or l.dtype), target_tree, keys_tree, diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index 25ea580aa..4f06eca1a 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -48,6 +48,14 @@ def get_variable(type_var: str): class RandomTest(chex.TestCase): + def test_tree_random_split(self): + rng_key = jrd.PRNGKey(0) + tree = {'a': jnp.zeros(2), 'b': {'c': [jnp.ones(3), jnp.zeros([4, 5])]}} + keys_tree = otu.tree_random_split(rng_key, tree) + + with self.subTest('Test structure matches'): + self.assertEqual(jtu.tree_structure(tree), jtu.tree_structure(keys_tree)) + @parameterized.product( _SAMPLER_DTYPES, type_var=['real_array', 'complex_array', 'pytree'],