Skip to content

Commit

Permalink
Add utility to fetch common/lowest/highest dtype of a tree.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675205179
  • Loading branch information
vroulet authored and OptaxDev committed Sep 23, 2024
1 parent b06f6c5 commit 11cc0ea
Show file tree
Hide file tree
Showing 4 changed files with 345 additions and 11 deletions.
14 changes: 12 additions & 2 deletions docs/api/utilities.rst
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ Tree
NamedTupleKey
tree_add
tree_add_scalar_mul
tree_cast
tree_div
tree_dtype
tree_get
tree_get_all_with_path
tree_l1_norm
Expand Down Expand Up @@ -122,6 +124,14 @@ Tree add and scalar multiply
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_add_scalar_mul

Tree cast
~~~~~~~~~
.. autofunction:: tree_cast

Tree data type
~~~~~~~~~~~~~~
.. autofunction:: tree_dtype

Tree divide
~~~~~~~~~~~
.. autofunction:: tree_div
Expand Down Expand Up @@ -154,8 +164,8 @@ Tree ones like
~~~~~~~~~~~~~~
.. autofunction:: tree_ones_like

Tree with random keys
~~~~~~~~~~~~~~~~~~~~~~~
Split key according to structure of a tree
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: tree_split_key_like

Tree with random values
Expand Down
2 changes: 2 additions & 0 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
"""The tree_utils sub-package."""

# pylint: disable=g-importing-member
from optax.tree_utils._casting import tree_assert_dtype_preserved
from optax.tree_utils._casting import tree_cast
from optax.tree_utils._casting import tree_dtype
from optax.tree_utils._random import tree_random_like
from optax.tree_utils._random import tree_split_key_like
from optax.tree_utils._state_utils import NamedTupleKey
Expand Down
230 changes: 227 additions & 3 deletions optax/tree_utils/_casting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,242 @@
# ==============================================================================
"""Utilities to cast pytrees to specific dtypes."""

import functools
from typing import Optional

import chex
import jax
import jax.numpy as jnp


def tree_cast(
    tree: chex.ArrayTree, dtype: Optional[chex.ArrayDType]
) -> chex.ArrayTree:
  """Cast tree to given dtype, skip if None.

  Examples:
    >>> import jax.numpy as jnp
    >>> import optax
    >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
    ...         'c': jnp.array(2.0, dtype=jnp.float32)}
    >>> optax.tree_utils.tree_cast(tree, dtype=jnp.bfloat16)
    {'a': {'b': Array(1, dtype=bfloat16)}, 'c': Array(2, dtype=bfloat16)}

  Args:
    tree: the tree to cast.
    dtype: the dtype to cast to, or None to skip.

  Returns:
    the tree, with leaves casted to dtype. When ``dtype`` is None the input
    tree object itself is returned, without copying.
  """
  if dtype is None:
    # Nothing to do: hand back the very same tree object.
    return tree
  # Every leaf is cast through its own ``astype`` method.
  return jax.tree.map(lambda leaf: leaf.astype(dtype), tree)


def tree_dtype(
    tree: chex.ArrayTree, mixed_dtype_handler: Optional[str] = None
) -> chex.ArrayDType:
  """Fetch dtype of tree.

  If the tree is empty, returns the default dtype of JAX arrays. This shortcut
  is taken before ``mixed_dtype_handler`` is inspected.

  Examples:
    >>> import jax.numpy as jnp
    >>> import optax
    >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float32)},
    ...         'c': jnp.array(2.0, dtype=jnp.float32)}
    >>> optax.tree_utils.tree_dtype(tree)
    dtype('float32')
    >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.float16)},
    ...         'c': jnp.array(2.0, dtype=jnp.float32)}
    >>> optax.tree_utils.tree_dtype(tree, 'lowest')
    dtype('float16')
    >>> optax.tree_utils.tree_dtype(tree, 'highest')
    dtype('float32')
    >>> tree = {'a': {'b': jnp.array(1.0, dtype=jnp.int32)},
    ...         'c': jnp.array(2.0, dtype=jnp.uint32)}
    >>> # optax.tree_utils.tree_dtype(tree, 'highest')
    >>> # -> will throw an error because int32 and uint32
    >>> # cannot be promoted to one another.
    >>> optax.tree_utils.tree_dtype(tree, 'promote')
    dtype('int64')

  Args:
    tree: the tree to fetch the dtype of.
    mixed_dtype_handler: how to handle mixed dtypes in the tree.

      - If ``mixed_dtype_handler=None``, returns the common dtype of the
        leaves of the tree if it exists, otherwise raises an error.
      - If ``mixed_dtype_handler='promote'``, folds the dtypes of the leaves
        into a single promoted dtype using :func:`jax.numpy.promote_types`.
      - If ``mixed_dtype_handler='highest'`` or
        ``mixed_dtype_handler='lowest'``, returns the highest/lowest dtype
        among the leaves. Dtypes are partially ordered by promotion:
        ``dtype1 <= dtype2`` when
        ``jax.numpy.promote_types(dtype1, dtype2) == dtype2``. Because some
        pairs of dtypes cannot be promoted to one another, the ordering is
        not total and these two options raise an error on such pairs.

  Returns:
    the dtype of the tree.

  Raises:
    ValueError: If ``mixed_dtype_handler`` is set to ``None`` and multiple
      dtypes are found in the tree.
    ValueError: If ``mixed_dtype_handler`` is set to ``'highest'`` or
      ``'lowest'`` and some leaves' dtypes in the tree cannot be promoted to
      one another.
    ValueError: If ``mixed_dtype_handler`` is none of the values listed above
      (only checked for non-empty trees).

  .. seealso:: :func:`jax.numpy.promote_types`,
    `Type promotion semantics in JAX
    <https://jax.readthedocs.io/en/latest/type_promotion.html#type-promotion>`_

  .. versionadded:: 0.2.4
  """
  leaves = jax.tree.leaves(tree)
  if not leaves:
    # Empty tree: report whatever dtype JAX assigns to an array built from an
    # empty list.
    return jnp.dtype(jnp.asarray(leaves))

  if mixed_dtype_handler is None:
    # Strict mode: the first leaf sets the expectation, every leaf must match.
    expected_dtype = jnp.asarray(leaves[0]).dtype
    _tree_assert_all_dtypes_equal(tree, expected_dtype)
    return expected_dtype

  # Pick the pairwise rule used to fold the leaves' dtypes into one.
  if mixed_dtype_handler == 'promote':
    combine = jnp.promote_types
  elif mixed_dtype_handler == 'highest':
    combine = _higher_dtype
  elif mixed_dtype_handler == 'lowest':
    combine = _lower_dtype
  else:
    raise ValueError(
        f'Invalid value for {mixed_dtype_handler=}, possible values are: None,'
        ' "promote", "highest", "lowest".'
    )

  leaf_dtypes = [jnp.asarray(leaf).dtype for leaf in leaves]
  return functools.reduce(combine, leaf_dtypes)


def tree_assert_dtype_preserved(
    tree: chex.ArrayTree,
    dtype: chex.ArrayDType,
) -> None:
  """Checks whether some elements of tree may be promoted to dtype.

  Transformations such as :func:`optax.scale_by_adam` or :func:`optax.trace`
  let the user choose a dtype for part of their state (e.g. the momentum
  term). This check verifies that combining any leaf of ``tree`` with
  ``dtype`` leaves that leaf's dtype unchanged under
  :func:`jax.numpy.promote_types`, so that the dtype of the updates stays
  consistent with the dtype of the parameters.

  Args:
    tree: the tree to check.
    dtype: the dtype to check against.

  Raises:
    ValueError: If any element of the tree is promoted to dtype. The message
      holds one line per offending leaf, each naming the leaf's path.

  .. versionadded:: 0.2.4
  """
  leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
  typed_leaves = [
      (path, jnp.asarray(leaf).dtype) for path, leaf in leaves_with_paths
  ]
  # A leaf is "preserved" when promoting it together with ``dtype`` gives back
  # the leaf's own dtype; everything else is reported.
  problems = [
      f'{dtype=} induces dtype promotion for {path} with dtype {leaf_dtype}.'
      for path, leaf_dtype in typed_leaves
      if jnp.promote_types(leaf_dtype, dtype) != leaf_dtype
  ]
  if problems:
    raise ValueError('\n'.join(problems))


def _tree_assert_all_dtypes_equal(
    tree: chex.ArrayTree, dtype: chex.ArrayDType
) -> None:
  """Checks that all leaves of the tree have the given dtype.

  Each leaf is converted with ``jnp.asarray`` and its dtype compared against
  ``dtype``. All mismatches are gathered before raising, so a single error
  reports every offending leaf.

  Args:
    tree: the tree to check.
    dtype: the dtype to check against.

  Raises:
    ValueError: If any element of the tree does not match the given dtype.
      The message holds one line per mismatching leaf, naming its path.
  """
  leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(tree)
  typed_leaves = [
      (path, jnp.asarray(leaf).dtype) for path, leaf in leaves_with_paths
  ]
  mismatches = [
      f'Expected {dtype=} for {path} but got {leaf_dtype}.'
      for path, leaf_dtype in typed_leaves
      if leaf_dtype != dtype
  ]
  if mismatches:
    raise ValueError('\n'.join(mismatches))


def _lower_dtype(
    dtype1: chex.ArrayDType, dtype2: chex.ArrayDType
) -> chex.ArrayDType:
  """Returns lower dtype among two dtypes, if any can be promoted to the other.

  The comparison relies on ``jnp.promote_types``: whichever argument equals
  the promoted dtype is the higher one, and the other argument is returned.

  Args:
    dtype1: The first dtype to compare.
    dtype2: The second dtype to compare.

  Returns:
    The lowest of the two dtypes, if any can be promoted to the other.

  Raises:
    ValueError: If none of the dtypes can be promoted to the other.
  """
  promoted = jnp.promote_types(dtype1, dtype2)
  if promoted == dtype1:
    # dtype1 absorbs dtype2, hence dtype2 is the lower one (or they are equal).
    return dtype2
  if promoted == dtype2:
    return dtype1
  # The promotion landed on a third dtype: the two inputs are not comparable.
  raise ValueError(
      f'Cannot compare dtype of {dtype1=} and {dtype2=}.'
      f' Neither {dtype1} nor {dtype2} can be promoted to the other.'
  )


def _higher_dtype(
    dtype1: chex.ArrayDType, dtype2: chex.ArrayDType
) -> chex.ArrayDType:
  """Returns higher dtype among two dtypes, if any can be promoted to the other.

  Implemented on top of ``_lower_dtype``: the argument that is not the lower
  one is returned.

  Args:
    dtype1: The first dtype to compare.
    dtype2: The second dtype to compare.

  Returns:
    The highest of the two dtypes, if any can be promoted to the other.

  Raises:
    ValueError: If none of the dtypes can be promoted to the other.
  """
  lower = _lower_dtype(dtype1, dtype2)
  # When dtype1 is the lower one (or both are equal), dtype2 is the answer.
  return dtype2 if lower == dtype1 else dtype1
Loading

0 comments on commit 11cc0ea

Please sign in to comment.