diff --git a/docs/api/utilities.rst b/docs/api/utilities.rst index c792944fd..816f82eea 100644 --- a/docs/api/utilities.rst +++ b/docs/api/utilities.rst @@ -92,7 +92,9 @@ Tree NamedTupleKey tree_add tree_add_scalar_mul + tree_cast tree_div + tree_dtype tree_get tree_get_all_with_path tree_l1_norm @@ -122,6 +124,14 @@ Tree add and scalar multiply ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_add_scalar_mul +Tree cast +~~~~~~~~~ +.. autofunction:: tree_cast + +Tree data type +~~~~~~~~~~~~~~ +.. autofunction:: tree_dtype + Tree divide ~~~~~~~~~~~ .. autofunction:: tree_div @@ -154,8 +164,8 @@ Tree ones like ~~~~~~~~~~~~~~ .. autofunction:: tree_ones_like -Tree with random keys -~~~~~~~~~~~~~~~~~~~~~~~ +Split key according to structure of a tree +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. autofunction:: tree_split_key_like Tree with random values diff --git a/optax/tree_utils/__init__.py b/optax/tree_utils/__init__.py index 44e2e195d..06aca15ef 100644 --- a/optax/tree_utils/__init__.py +++ b/optax/tree_utils/__init__.py @@ -15,7 +15,9 @@ """The tree_utils sub-package.""" # pylint: disable=g-importing-member +from optax.tree_utils._casting import tree_assert_dtype_preserved from optax.tree_utils._casting import tree_cast +from optax.tree_utils._casting import tree_dtype from optax.tree_utils._random import tree_random_like from optax.tree_utils._random import tree_split_key_like from optax.tree_utils._state_utils import NamedTupleKey diff --git a/optax/tree_utils/_casting.py b/optax/tree_utils/_casting.py index 83a33b032..07754445b 100644 --- a/optax/tree_utils/_casting.py +++ b/optax/tree_utils/_casting.py @@ -14,18 +14,242 @@ # ============================================================================== """Utilities to cast pytrees to specific dtypes.""" +import functools from typing import Optional import chex import jax +import jax.numpy as jnp def tree_cast( - tree: chex.ArrayTree, - dtype: Optional[chex.ArrayDType] + tree: chex.ArrayTree, dtype: 
Optional[chex.ArrayDType] ) -> chex.ArrayTree: - """Cast tree to given dtype, skip if None.""" + """Cast tree to given dtype, skip if None. + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, + ... 'c': jnp.array(2.0, dtype=jnp.float32)} + >>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16) + {'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)} + + Args: + tree: the tree to cast. + dtype: the dtype to cast to, or None to skip. + + Returns: + the tree, with leaves cast to dtype. + """ if dtype is not None: return jax.tree.map(lambda t: t.astype(dtype), tree) else: return tree + + +def tree_dtype( + tree: chex.ArrayTree, mixed_dtype_handler: Optional[str] = None +) -> chex.ArrayDType: + """Fetch dtype of tree. + + If the tree is empty, returns the default dtype of JAX arrays. + + Examples: + >>> import jax.numpy as jnp + >>> import optax + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, + ... 'c': jnp.array(2.0, dtype=jnp.float32)} + >>> optax.tree_utils.tree_dtype(tree) + dtype('float32') + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)}, + ... 'c': jnp.array(2.0, dtype=jnp.float32)} + >>> optax.tree_utils.tree_dtype(tree, 'lowest') + dtype('float16') + >>> optax.tree_utils.tree_dtype(tree, 'highest') + dtype('float32') + >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)}, + ... 'c': jnp.array(2.0, dtype=jnp.uint32)} + >>> # optax.tree_utils.tree_dtype(tree, 'highest') + >>> # -> will throw an error because int32 and uint32 + >>> # cannot be promoted to one another. + >>> optax.tree_utils.tree_dtype(tree, 'promote') + dtype('int64') + + Args: + tree: the tree to fetch the dtype of. + mixed_dtype_handler: how to handle mixed dtypes in the tree. + + - If ``mixed_dtype_handler=None``, returns the common dtype of the leaves + of the tree if it exists, otherwise raises an error. 
+ - If ``mixed_dtype_handler='promote'``, promotes the dtypes of the leaves + of the tree to a common promoted dtype using + :func:`jax.numpy.promote_types`. + - If ``mixed_dtype_handler='highest'`` or + ``mixed_dtype_handler='lowest'``, returns the highest/lowest dtype of + the leaves of the tree. We consider a partial ordering of dtypes as + ``dtype1 <= dtype2`` if ``dtype1`` is promoted to ``dtype2``, that is, + if ``jax.numpy.promote_types(dtype1, dtype2) == dtype2``. Since some + dtypes cannot be promoted to one another, this is not a total ordering, + and the 'highest' or 'lowest' options may not be applicable. These + options will throw an error if the dtypes of the leaves of the tree + cannot be promoted to one another. + + Returns: + the dtype of the tree. + + Raises: + ValueError: If ``mixed_dtype_handler`` is set to ``None`` and multiple + dtypes are found in the tree. + ValueError: If ``mixed_dtype_handler`` is set to ``'highest'`` or + ``'lowest'`` and some leaves' dtypes in the tree cannot be promoted to one + another. + + .. seealso:: :func:`jax.numpy.promote_types`, + `Type promotion semantics in JAX + <https://jax.readthedocs.io/en/latest/type_promotion.html#type-promotion>`_ + + .. versionadded:: 0.2.4 + """ + leaves = jax.tree.leaves(tree) + if not leaves: + # If the tree is empty, we return the default dtype as given by JAX on + # empty lists. 
+ return jnp.dtype(jnp.asarray(leaves)) + if mixed_dtype_handler is None: + dtype = jnp.asarray(leaves[0]).dtype + _tree_assert_all_dtypes_equal(tree, dtype) + return dtype + elif mixed_dtype_handler == 'promote': + promoted_dtype = functools.reduce( + jnp.promote_types, [jnp.asarray(x).dtype for x in leaves] + ) + return promoted_dtype + elif mixed_dtype_handler == 'highest': + highest_dtype = functools.reduce( + _higher_dtype, [jnp.asarray(x).dtype for x in leaves] + ) + return highest_dtype + elif mixed_dtype_handler == 'lowest': + lowest_dtype = functools.reduce( + _lower_dtype, [jnp.asarray(x).dtype for x in leaves] + ) + return lowest_dtype + else: + raise ValueError( + f'Invalid value for {mixed_dtype_handler=}, possible values are: None,' + ' "promote", "highest", "lowest".' + ) + + +def tree_assert_dtype_preserved( + tree: chex.ArrayTree, + dtype: chex.ArrayDType, +) -> None: + """Checks whether some elements of tree may be promoted to dtype. + + Some transformations like :func:`optax.scale_by_adam`, :func:`optax.trace` + allow the user to specify a dtype for some of the state's parameters (e.g. the + momentum term). This function checks that the specified dtype of the state's + parameters does not induce a dtype promotion of any of the parameters. That + way we can ensure that the dtype of the updates are consistent with the dtype + of the parameters. + + Args: + tree: the tree to check. + dtype: the dtype to check against. + + Raises: + ValueError: If any element of the tree is promoted to dtype. + + .. versionadded:: 0.2.4 + """ + + def _assert_dtype_preserved(path, x): + x_dtype = jnp.asarray(x).dtype + if jnp.promote_types(x_dtype, dtype) != x_dtype: + err_msg = ( + f'{dtype=} induces dtype promotion for {path} with dtype {x_dtype}.' 
+ ) + return err_msg + + err_msgs = jax.tree.leaves( + jax.tree_util.tree_map_with_path(_assert_dtype_preserved, tree) + ) + err_msgs = [err_msg for err_msg in err_msgs if err_msg is not None] + if err_msgs: + raise ValueError('\n'.join(err_msgs)) + + +def _tree_assert_all_dtypes_equal( + tree: chex.ArrayTree, dtype: chex.ArrayDType +) -> None: + """Checks that all leaves of the tree have the given dtype. + + Args: + tree: the tree to check. + dtype: the dtype to check against. + + Raises: + ValueError: If any element of the tree does not match the given dtype. + """ + + def _assert_dtypes_equal(path, x): + x_dtype = jnp.asarray(x).dtype + if x_dtype != dtype: + err_msg = f'Expected {dtype=} for {path} but got {x_dtype}.' + return err_msg + + err_msgs = jax.tree.leaves( + jax.tree_util.tree_map_with_path(_assert_dtypes_equal, tree) + ) + err_msgs = [err_msg for err_msg in err_msgs if err_msg is not None] + if err_msgs: + raise ValueError('\n'.join(err_msgs)) + + +def _lower_dtype( + dtype1: chex.ArrayDType, dtype2: chex.ArrayDType +) -> chex.ArrayDType: + """Returns lower dtype among two dtypes, if any can be promoted to the other. + + Args: + dtype1: The first dtype to compare. + dtype2: The second dtype to compare. + + Returns: + The lowest of the two dtypes, if any can be promoted to the other. + + Raises: + ValueError: If none of the dtypes can be promoted to the other. + """ + if jnp.promote_types(dtype1, dtype2) == dtype1: + return dtype2 + elif jnp.promote_types(dtype1, dtype2) == dtype2: + return dtype1 + else: + raise ValueError( + f'Cannot compare dtype of {dtype1=} and {dtype2=}.' + f' Neither {dtype1} nor {dtype2} can be promoted to the other.' + ) + + +def _higher_dtype( + dtype1: chex.ArrayDType, dtype2: chex.ArrayDType +) -> chex.ArrayDType: + """Returns higher dtype among two dtypes, if any can be promoted to the other. + + Args: + dtype1: The first dtype to compare. + dtype2: The second dtype to compare. 
+ + Returns: + The highest of the two dtypes, if any can be promoted to the other. + + Raises: + ValueError: If none of the dtypes can be promoted to the other. + """ + if _lower_dtype(dtype1, dtype2) == dtype1: + return dtype2 + else: + return dtype1 diff --git a/optax/tree_utils/_casting_test.py b/optax/tree_utils/_casting_test.py index 08c846206..a8c2bc493 100644 --- a/optax/tree_utils/_casting_test.py +++ b/optax/tree_utils/_casting_test.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Tests for optax.tree_utils._casting.""" +"""Tests for tree utilities on data types.""" from absl.testing import absltest from absl.testing import parameterized - import jax import jax.numpy as jnp import numpy as np - from optax import tree_utils as otu @@ -41,9 +39,109 @@ def _build_tree(val1, val2): tree = _build_tree(b, c) tree = otu.tree_cast(tree, dtype=dtype) - jax.tree.map( - np.testing.assert_array_equal, tree, _build_tree(new_b, new_c) - ) + jax.tree.map(np.testing.assert_array_equal, tree, _build_tree(new_b, new_c)) + + def test_tree_dtype(self): + """Test fetching data type of a tree.""" + + with self.subTest('Check that it returns the right dtype'): + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.float32)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + dtype = otu.tree_dtype(tree) + self.assertEqual(dtype, jnp.float32) + + with self.subTest('Check that it raises an error if dtypes differ'): + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + self.assertRaises(ValueError, otu.tree_dtype, tree) + + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + + with self.subTest('Check that it works with lowest common dtype'): + dtype = otu.tree_dtype(tree, 'lowest') + self.assertEqual(dtype, jnp.bfloat16) + 
+ with self.subTest('Check that it works with highest common dtype'): + dtype = otu.tree_dtype(tree, 'highest') + self.assertEqual(dtype, jnp.float32) + + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float16), + } + + with self.subTest('Check that it works when promoting mixed dtype'): + dtype = otu.tree_dtype(tree, 'promote') + self.assertEqual(dtype, jnp.float32) + + with self.subTest( + 'Check that it raises an error if no dtypes cannot be promoted to one' + ' another' + ): + self.assertRaises(ValueError, otu.tree_dtype, tree, 'lowest') + self.assertRaises(ValueError, otu.tree_dtype, tree, 'highest') + + def test_tree_assert_dtype_preserved(self): + """Test asserting no promotion of dtypes in a tree for given dtype.""" + tree = { + 'a': {'b': jnp.array(1.0, dtype=jnp.bfloat16)}, + 'c': jnp.array(2.0, dtype=jnp.float32), + } + + with self.subTest( + 'Check that it raises an error if given dtype induces promotion of at' + ' least one element.' + ): + with self.assertRaises(ValueError): + otu.tree_assert_dtype_preserved(tree, jnp.float32) + + with self.subTest( + 'Check that it runs fine if no element gets promoted by given dtype.' + ): + otu.tree_assert_dtype_preserved(tree, jnp.bfloat16) + + with self.subTest( + 'Check that it naturally succeeds when considering lowest common dtype.' + ): + otu.tree_assert_dtype_preserved(tree, otu.tree_dtype(tree, 'lowest')) + + with self.subTest( + 'Check that it naturally fails when considering highest common dtype.' 
+ ): + with self.assertRaises(ValueError): + otu.tree_assert_dtype_preserved(tree, otu.tree_dtype(tree, 'highest')) + + with self.subTest('Check that it works with empty trees.'): + for tree in [(), {}, None]: + otu.tree_assert_dtype_preserved(tree, jnp.float32) + + @parameterized.named_parameters( + dict(testcase_name='empty_dict', tree={}), + dict(testcase_name='empty_list', tree=[]), + dict(testcase_name='empty_tuple', tree=()), + dict(testcase_name='empty_none', tree=None), + ) + def test_tree_dtype_utilities_with_empty_trees(self, tree): + """Test tree data type utilities on empty trees.""" + default_dtype = jnp.asarray(1.0).dtype + + with self.subTest('Check tree_dtype works with empty trees.'): + dtype = otu.tree_dtype(tree) + self.assertEqual(dtype, default_dtype) + + with self.subTest( + 'Check tree_assert_dtype_preserved succeeds with any dtype for' + ' empty trees.' + ): + # There is no array in the tree to check, so it should succeed. + otu.tree_assert_dtype_preserved(tree, jnp.complex64) if __name__ == '__main__':