diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index b199697f..c792944f 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -101,6 +101,7 @@ Tree tree_mul tree_ones_like tree_random_like + tree_split_key_like tree_scalar_mul tree_set tree_sub @@ -153,6 +154,10 @@ Tree ones like ~~~~~~~~~~~~~~ .. autofunction:: tree_ones_like +Tree with random keys +~~~~~~~~~~~~~~~~~~~~~~~ +.. autofunction:: tree_split_key_like + Tree with random values ~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_random_like diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index e89aef86..44e2e195 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -17,6 +17,7 @@ # pylint: disable=g-importing-member from optax.tree_utils._casting import tree_cast from optax.tree_utils._random import tree_random_like +from optax.tree_utils._random import tree_split_key_like from optax.tree_utils._state_utils import NamedTupleKey from optax.tree_utils._state_utils import tree_get from optax.tree_utils._state_utils import tree_get_all_with_path diff --git a/optax/tree_utils/_random.py b/optax/tree_utils/_random.py index 33783b2b..6b4fab30 100644 --- a/optax/tree_utils/_random.py +++ b/optax/tree_utils/_random.py @@ -21,7 +21,7 @@ from jax import tree_util as jtu -def _tree_rng_keys_split( +def tree_split_key_like( rng_key: chex.PRNGKey, target_tree: chex.ArrayTree ) -> chex.ArrayTree: """Split keys to match structure of target tree. @@ -67,7 +67,7 @@ def tree_random_like( .. versionadded:: 0.2.1 """ - keys_tree = _tree_rng_keys_split(rng_key, target_tree) + keys_tree = tree_split_key_like(rng_key, target_tree) return jtu.tree_map( lambda l, k: sampler(k, l.shape, dtype or l.dtype), target_tree, diff --git a/optax/tree_utils/_random_test.py b/optax/tree_utils/_random_test.py index 25ea580a..077ca678 100644 --- a/optax/tree_utils/_random_test.py +++ b/optax/tree_utils/_random_test.py @@ -22,6 +22,7 @@ import jax.numpy as jnp import jax.random as jrd import jax.tree_util as jtu +import numpy as np from optax import tree_utils as otu # We consider samplers with varying input dtypes, we do not test all possible @@ -48,6 +49,19 @@ def get_variable(type_var: str): class RandomTest(chex.TestCase): + def test_tree_split_key_like(self): + rng_key = jrd.PRNGKey(0) + tree = {'a': jnp.zeros(2), 'b': {'c': [jnp.ones(3), jnp.zeros([4, 5])]}} + keys_tree = otu.tree_split_key_like(rng_key, tree) + + with self.subTest('Test structure matches'): + self.assertEqual(jtu.tree_structure(tree), jtu.tree_structure(keys_tree)) + + with self.subTest('Test random key split'): + fst = jnp.stack(jtu.tree_flatten(keys_tree)[0]) + snd = jrd.split(rng_key, jtu.tree_structure(tree).num_leaves) + np.testing.assert_array_equal(fst, snd) + @parameterized.product( _SAMPLER_DTYPES, type_var=['real_array', 'complex_array', 'pytree'],