Skip to content

Commit

Permalink
Add optax.tree_utils.tree_random_like.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Sep 17, 2024
1 parent ee63e45 commit b9bdf05
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 6 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Tree
tree_map_params
tree_mul
tree_ones_like
tree_random_split
tree_random_like
tree_scalar_mul
tree_set
Expand Down Expand Up @@ -153,6 +154,10 @@ Tree ones like
~~~~~~~~~~~~~~
.. autofunction:: tree_ones_like

Tree with random keys
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_random_split

Tree with random values
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_random_like
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=g-importing-member
from optax.tree_utils._casting import tree_cast
from optax.tree_utils._random import tree_random_like
from optax.tree_utils._random import tree_random_split
from optax.tree_utils._state_utils import NamedTupleKey
from optax.tree_utils._state_utils import tree_get
from optax.tree_utils._state_utils import tree_get_all_with_path
Expand Down
11 changes: 5 additions & 6 deletions optax/tree_utils/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,9 @@

import chex
import jax
from jax import tree_util as jtu


def _tree_rng_keys_split(
def tree_random_split(
rng_key: chex.PRNGKey, target_tree: chex.ArrayTree
) -> chex.ArrayTree:
"""Split keys to match structure of target tree.
Expand All @@ -33,9 +32,9 @@ def _tree_rng_keys_split(
Returns:
a tree of rng keys.
"""
tree_def = jtu.tree_structure(target_tree)
tree_def = jax.tree.structure(target_tree)
keys = jax.random.split(rng_key, tree_def.num_leaves)
return jtu.tree_unflatten(tree_def, keys)
return jax.tree.unflatten(tree_def, keys)


def tree_random_like(
Expand Down Expand Up @@ -67,8 +66,8 @@ def tree_random_like(
.. versionadded:: 0.2.1
"""
keys_tree = _tree_rng_keys_split(rng_key, target_tree)
return jtu.tree_map(
keys_tree = tree_random_split(rng_key, target_tree)
return jax.tree.map(
lambda l, k: sampler(k, l.shape, dtype or l.dtype),
target_tree,
keys_tree,
Expand Down
8 changes: 8 additions & 0 deletions optax/tree_utils/_random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ def get_variable(type_var: str):

class RandomTest(chex.TestCase):

def test_tree_random_split(self):
rng_key = jrd.PRNGKey(0)
tree = {'a': jnp.zeros(2), 'b': {'c': [jnp.ones(3), jnp.zeros([4, 5])]}}
keys_tree = otu.tree_random_split(rng_key, tree)

with self.subTest('Test structure matches'):
self.assertEqual(jtu.tree_structure(tree), jtu.tree_structure(keys_tree))

@parameterized.product(
_SAMPLER_DTYPES,
type_var=['real_array', 'complex_array', 'pytree'],
Expand Down

0 comments on commit b9bdf05

Please sign in to comment.