Skip to content

Commit

Permalink
Merge pull request #1063 from carlosgmartin:tree_random_split
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 676955605
  • Loading branch information
OptaxDev committed Sep 20, 2024
2 parents 2a88336 + a9603c9 commit fad9a20
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 2 deletions.
5 changes: 5 additions & 0 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Tree
tree_mul
tree_ones_like
tree_random_like
tree_split_key_like
tree_scalar_mul
tree_set
tree_sub
Expand Down Expand Up @@ -153,6 +154,10 @@ Tree ones like
~~~~~~~~~~~~~~
.. autofunction:: tree_ones_like

Tree with random keys
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_split_key_like

Tree with random values
~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_random_like
Expand Down
1 change: 1 addition & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# pylint: disable=g-importing-member
from optax.tree_utils._casting import tree_cast
from optax.tree_utils._random import tree_random_like
from optax.tree_utils._random import tree_split_key_like
from optax.tree_utils._state_utils import NamedTupleKey
from optax.tree_utils._state_utils import tree_get
from optax.tree_utils._state_utils import tree_get_all_with_path
Expand Down
4 changes: 2 additions & 2 deletions optax/tree_utils/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from jax import tree_util as jtu


def _tree_rng_keys_split(
def tree_split_key_like(
rng_key: chex.PRNGKey, target_tree: chex.ArrayTree
) -> chex.ArrayTree:
"""Split keys to match structure of target tree.
Expand Down Expand Up @@ -67,7 +67,7 @@ def tree_random_like(
.. versionadded:: 0.2.1
"""
keys_tree = _tree_rng_keys_split(rng_key, target_tree)
keys_tree = tree_split_key_like(rng_key, target_tree)
return jtu.tree_map(
lambda l, k: sampler(k, l.shape, dtype or l.dtype),
target_tree,
Expand Down
14 changes: 14 additions & 0 deletions optax/tree_utils/_random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import jax.numpy as jnp
import jax.random as jrd
import jax.tree_util as jtu
import numpy as np
from optax import tree_utils as otu

# We consider samplers with varying input dtypes, we do not test all possible
Expand All @@ -48,6 +49,19 @@ def get_variable(type_var: str):

class RandomTest(chex.TestCase):

def test_tree_split_key_like(self):
rng_key = jrd.PRNGKey(0)
tree = {'a': jnp.zeros(2), 'b': {'c': [jnp.ones(3), jnp.zeros([4, 5])]}}
keys_tree = otu.tree_split_key_like(rng_key, tree)

with self.subTest('Test structure matches'):
self.assertEqual(jtu.tree_structure(tree), jtu.tree_structure(keys_tree))

with self.subTest('Test random key split'):
fst = jnp.stack(jtu.tree_flatten(keys_tree)[0])
snd = jrd.split(rng_key, jtu.tree_structure(tree).num_leaves)
np.testing.assert_array_equal(fst, snd)

@parameterized.product(
_SAMPLER_DTYPES,
type_var=['real_array', 'complex_array', 'pytree'],
Expand Down

0 comments on commit fad9a20

Please sign in to comment.