diff --git a/jax_scalify/__init__.py b/jax_scalify/__init__.py index 838fd2c..32988d7 100644 --- a/jax_scalify/__init__.py +++ b/jax_scalify/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from . import core, lax, ops +from . import core, lax, ops, tree from ._version import __version__ from .core import ( # noqa: F401 Pow2RoundMode, diff --git a/jax_scalify/tree/__init__.py b/jax_scalify/tree/__init__.py new file mode 100644 index 0000000..27b6b52 --- /dev/null +++ b/jax_scalify/tree/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +from .tree_util import all, astype, flatten, leaves, map, structure, unflatten # noqa: F401 diff --git a/jax_scalify/tree/tree_util.py b/jax_scalify/tree/tree_util.py new file mode 100644 index 0000000..45f0315 --- /dev/null +++ b/jax_scalify/tree/tree_util.py @@ -0,0 +1,125 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +from typing import Any, Callable + +import jax +import jax.numpy as jnp +from jax import tree_util + +from jax_scalify.core import DTypeLike, is_scaled_leaf + +Leaf = Any + + +def astype(tree: Any, dtype: DTypeLike, floating_only: bool = False) -> Any: + """Map `astype` method to all pytree leaves, `Array` or `ScaledArray`. + + Args: + tree: the pytree to cast. + dtype: Dtype to cast to. + floating_only: Only convert leaves with floating datatype. + + Returns: + A new PyTree with the same structure, with casting to new dtype. + """ + if floating_only: + # Convert only leaves with floating dtype. + cast_fn = lambda v: v.astype(dtype) if jnp.issubdtype(v.dtype, jnp.floating) else v + return tree_util.tree_map(cast_fn, tree, is_leaf=is_scaled_leaf) + return tree_util.tree_map(lambda v: v.astype(dtype), tree, is_leaf=is_scaled_leaf) + + +def all(tree: Any) -> bool: + """Call all() over the leaves of a tree, `Array` or `ScaledArray` + + Args: + tree: the pytree to evaluate + Returns: + result: boolean True or False + """ + return all(jax.tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf)) + + +def flatten(tree: Any) -> tuple[list[Leaf], tree_util.PyTreeDef]: + """Flattens a pytree, with `Array` or `ScaledArray` leaves. + + The flattening order (i.e. the order of elements in the output list) + is deterministic, corresponding to a left-to-right depth-first tree + traversal. + + Args: + tree: a pytree to flatten. + + Returns: + A pair where the first element is a list of leaf values and the second + element is a treedef representing the structure of the flattened tree. + + See Also: + - :func:`jax_scalify.tree.leaves` + - :func:`jax_scalify.tree.structure` + - :func:`jax_scalify.tree.unflatten` + """ + return tree_util.tree_flatten(tree, is_leaf=is_scaled_leaf) + + +def leaves( + tree: Any, +) -> list[Leaf]: + """Gets the leaves (`Array` or `ScaledArray`) of a pytree. + + Args: + tree: the pytree for which to get the leaves + + Returns: + leaves: a list of tree leaves. + + See Also: + - :func:`jax_scalify.tree.flatten` + - :func:`jax_scalify.tree.structure` + - :func:`jax_scalify.tree.unflatten` + """ + return tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf) + + +def map(f: Callable[..., Any], tree: Any, *rest: Any) -> Any: + """Maps a multi-input function over pytree args to produce a new pytree. + + Args: + f: function that takes ``1 + len(rest)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree: a pytree to be mapped over, with each leaf providing the first + positional argument to ``f``. + rest: a tuple of pytrees, each of which has the same structure as ``tree`` + or has ``tree`` as a prefix. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each + leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding + leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in + ``rest``. + + See Also: + - :func:`jax_scalify.tree.leaves` + - :func:`jax_scalify.tree.reduce` + """ + return tree_util.tree_map(f, tree, *rest, is_leaf=is_scaled_leaf) + + +def structure(tree: Any) -> tree_util.PyTreeDef: + """Gets the treedef for a pytree, with `Array` or `ScaledArray` leaves. + + Args: + tree: the pytree for which to get the leaves + + Returns: + pytreedef: a PyTreeDef representing the structure of the tree. + + See Also: + - :func:`jax_scalify.tree.flatten` + - :func:`jax_scalify.tree.leaves` + - :func:`jax_scalify.tree.unflatten` + """ + return tree_util.tree_structure(tree, is_leaf=is_scaled_leaf) + + +# Alias of JAX tree unflatten. +unflatten = jax.tree_util.tree_unflatten diff --git a/tests/tree/test_tree_util.py b/tests/tree/test_tree_util.py new file mode 100644 index 0000000..9f391ef --- /dev/null +++ b/tests/tree/test_tree_util.py @@ -0,0 +1,51 @@ +# Copyright (c) 2024 Graphcore Ltd. All rights reserved. +import chex +import numpy as np + +import jax_scalify as jsa + + +class ScalifyTreeUtilTests(chex.TestCase): + def test__tree_flatten__proper_result(self): + values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} + outputs, _ = jsa.tree.flatten(values) + assert len(outputs) == 2 + assert outputs[0] == 2 + assert isinstance(outputs[1], jsa.ScaledArray) + assert np.asarray(outputs[1]) == 1.5 + + def test__tree_leaves__proper_result(self): + values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} + outputs = jsa.tree.leaves(values) + assert len(outputs) == 2 + assert outputs[0] == 2 + assert isinstance(outputs[1], jsa.ScaledArray) + assert np.asarray(outputs[1]) == 1.5 + + def test__tree_structure__proper_result(self): + values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} + pytree = jsa.tree.structure(values) + assert pytree == jsa.tree.flatten(values)[1] + + def test__tree_unflatten__proper_result(self): + values_in = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} + outputs, pytree = jsa.tree.flatten(values_in) + values_out = jsa.tree.unflatten(pytree, outputs) + assert values_out == values_in + + def test__tree_map__proper_result(self): + values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} + outputs = jsa.tree.map(lambda v: v.dtype, values) + assert outputs == {"a": np.int32, "b": np.float32} + + def test__tree_astype__all_leaves_casting(self): + values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} + outputs = jsa.tree.astype(values, dtype=np.float16) + dtypes = jsa.tree.map(lambda v: v.dtype, outputs) + assert dtypes == {"a": np.float16, "b": np.float16} + + def test__tree_astype__only_float_casting(self): + values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} + outputs = jsa.tree.astype(values, dtype=np.float16, floating_only=True) + dtypes = jsa.tree.map(lambda v: v.dtype, outputs) + assert dtypes == {"a": np.int32, "b": np.float16}