Skip to content

Commit

Permalink
Add dtype option to tree_random_like
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674482266
  • Loading branch information
vroulet authored and OptaxDev committed Sep 14, 2024
1 parent c0e4228 commit 469c878
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 72 deletions.
17 changes: 12 additions & 5 deletions optax/tree_utils/_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,12 @@
# ==============================================================================
"""Utilities to generate random pytrees."""

from typing import Callable
from typing import Callable, Optional

import chex
import jax
from jax import tree_util as jtu

from optax._src import base


def _tree_rng_keys_split(
rng_key: chex.PRNGKey, target_tree: chex.ArrayTree
Expand All @@ -44,15 +42,24 @@ def tree_random_like(
rng_key: chex.PRNGKey,
target_tree: chex.ArrayTree,
sampler: Callable[
[chex.PRNGKey, base.Shape], chex.Array
[chex.PRNGKey, chex.Shape, chex.ArrayDType], chex.Array
] = jax.random.normal,
dtype: Optional[chex.ArrayDType] = None,
) -> chex.ArrayTree:
"""Create tree with random entries of the same shape as target tree.
.. warning::
The possible dtypes may be limited by the sampler. For example,
``jax.random.rademacher`` only supports integer dtypes and will raise an
error if the requested ``dtype`` (or, when ``dtype`` is None, the dtype of
the target tree) is not an integer type.
Args:
rng_key: the key for the random number generator.
target_tree: the tree whose structure to match. Leaves must be arrays.
sampler: the noise sampling function, by default ``jax.random.normal``.
dtype: the desired dtype for the random numbers, passed to ``sampler``. If
None, each leaf is sampled with the dtype of the matching target leaf.
Returns:
a random tree with the same structure as ``target_tree``, whose leaves have
Expand All @@ -62,7 +69,7 @@ def tree_random_like(
"""
keys_tree = _tree_rng_keys_split(rng_key, target_tree)
return jtu.tree_map(
lambda l, k: sampler(k, l.shape),
lambda l, k: sampler(k, l.shape, dtype or l.dtype),
target_tree,
keys_tree,
)
128 changes: 61 additions & 67 deletions optax/tree_utils/_random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,89 +14,83 @@
# ==============================================================================
"""Tests for optax.tree_utils._random."""

from typing import Callable

from absl.testing import absltest
from absl.testing import parameterized
import chex
import jax
from jax import tree_util as jtu
import jax.numpy as jnp
import numpy as np

import jax.random as jrd
import jax.tree_util as jtu
from optax import tree_utils as otu

# Sampler/dtype combinations exercised by the tests. The samplers are chosen
# to cover different supported dtypes; this is deliberately not an exhaustive
# list of the samplers available in `jax.random`.
_SAMPLER_DTYPES = (
    {'sampler': jrd.normal, 'dtype': None},
    {'sampler': jrd.normal, 'dtype': 'bfloat16'},
    {'sampler': jrd.normal, 'dtype': 'float32'},
    {'sampler': jrd.rademacher, 'dtype': 'int32'},
    {'sampler': jrd.bits, 'dtype': 'uint32'},
)

class RandomTest(absltest.TestCase):

def setUp(self):
super().setUp()
rng = np.random.RandomState(0)
def get_variable(type_var: str):
  """Build a small, fixed test input of the requested kind.

  Args:
    type_var: one of ``'real_array'``, ``'complex_array'`` or ``'pytree'``.

  Returns:
    For ``'real_array'``, an array built from two real values. For
    ``'complex_array'``, an array built from two complex values. For
    ``'pytree'``, a dict holding a scalar, a tuple of scalars and an array,
    with every leaf converted through ``jnp.asarray``.

  Raises:
    ValueError: if ``type_var`` is not one of the supported kinds.
  """
  if type_var == 'real_array':
    return jnp.asarray([1.0, 2.0])
  if type_var == 'complex_array':
    return jnp.asarray([1.0 + 1j * 2.0, 3.0 + 4j * 5.0])
  if type_var == 'pytree':
    pytree = {'k1': 1.0, 'k2': (2.0, 3.0), 'k3': jnp.asarray([4.0, 5.0])}
    return jtu.tree_map(jnp.asarray, pytree)
  # Fail loudly: silently returning None would hand callers an empty pytree,
  # so a mistyped kind could make a test pass without checking anything.
  raise ValueError(
      f'Unknown type_var: {type_var!r}. Expected one of '
      "'real_array', 'complex_array' or 'pytree'."
  )

self.rng_jax = jax.random.PRNGKey(0)

self.tree_a = (rng.randn(20, 10) + 1j * rng.randn(20, 10), rng.randn(20))
self.tree_b = (rng.randn(20, 10), rng.randn(20))
class RandomTest(chex.TestCase):

self.tree_a_dict = jtu.tree_map(
jnp.asarray,
(
1.0,
{'k1': 1.0, 'k2': (1.0, 1.0)},
1.0
)
)
self.tree_b_dict = jtu.tree_map(
jnp.asarray,
(
1.0,
{'k1': 2.0, 'k2': (3.0, 4.0)},
5.0
)
@parameterized.product(
_SAMPLER_DTYPES,
type_var=['real_array', 'complex_array', 'pytree'],
)
def test_tree_random_like(
self,
sampler: Callable[
[chex.PRNGKey, chex.Shape, chex.ArrayDType], chex.Array
],
dtype: str,
type_var: str,
):
"""Test that tree_random_like matches its flat counterpart."""
if dtype is not None:
dtype = jnp.dtype(dtype)
rng_key = jrd.PRNGKey(0)
target_tree = get_variable(type_var)

rand_tree = otu.tree_random_like(
rng_key, target_tree, sampler=sampler, dtype=dtype
)

self.array_a = rng.randn(20) + 1j * rng.randn(20)
self.array_b = rng.randn(20)
flat_tree, tree_def = jtu.tree_flatten(target_tree)

self.tree_a_dict_jax = jtu.tree_map(jnp.array, self.tree_a_dict)
self.tree_b_dict_jax = jtu.tree_map(jnp.array, self.tree_b_dict)
with self.subTest('Test structure matches'):
self.assertEqual(tree_def, jtu.tree_structure(rand_tree))

def test_tree_random_like(self, eps=1e-6):
"""Test for `tree_random_like`.
with self.subTest('Test tree_random_like matches flat random like'):
flat_rand_tree, _ = jtu.tree_flatten(rand_tree)
keys = jrd.split(rng_key, tree_def.num_leaves)
expected_flat_rand_tree = [
sampler(key, x.shape, dtype or x.dtype)
for key, x in zip(keys, flat_tree)
]
chex.assert_trees_all_close(flat_rand_tree, expected_flat_rand_tree)

Args:
eps: amount of noise.
with self.subTest('Test dtype are as expected'):
if dtype is not None:
for x in jtu.tree_leaves(rand_tree):
self.assertEqual(x.dtype, dtype)
else:
chex.assert_trees_all_equal_dtypes(rand_tree, target_tree)

Tests that `tree_random_like` generates a tree of the proper structure,
that it can be added to a target tree with a small multiplicative factor
without errors, and that the resulting addition is close to the original.
"""
rand_tree_a = otu.tree_random_like(self.rng_jax, self.tree_a)
rand_tree_b = otu.tree_random_like(self.rng_jax, self.tree_b)
rand_tree_a_dict = otu.tree_random_like(self.rng_jax, self.tree_a_dict_jax)
rand_tree_b_dict = otu.tree_random_like(self.rng_jax, self.tree_b_dict_jax)
rand_array_a = otu.tree_random_like(self.rng_jax, self.array_a)
rand_array_b = otu.tree_random_like(self.rng_jax, self.array_b)
sum_tree_a = otu.tree_add_scalar_mul(self.tree_a, eps, rand_tree_a)
sum_tree_b = otu.tree_add_scalar_mul(self.tree_b, eps, rand_tree_b)
sum_tree_a_dict = otu.tree_add_scalar_mul(self.tree_a_dict,
eps,
rand_tree_a_dict)
sum_tree_b_dict = otu.tree_add_scalar_mul(self.tree_b_dict,
eps,
rand_tree_b_dict)
sum_array_a = otu.tree_add_scalar_mul(self.array_a, eps, rand_array_a)
sum_array_b = otu.tree_add_scalar_mul(self.array_b, eps, rand_array_b)
tree_sums = [sum_tree_a,
sum_tree_b,
sum_tree_a_dict,
sum_tree_b_dict,
sum_array_a,
sum_array_b]
trees = [self.tree_a,
self.tree_b,
self.tree_a_dict,
self.tree_b_dict,
self.array_a,
self.array_b]
chex.assert_trees_all_close(trees, tree_sums, atol=1e-5)

if __name__ == '__main__':
absltest.main()

0 comments on commit 469c878

Please sign in to comment.