-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implementation of
jax_scalify.tree
sub-module. (#117)
PyTree methods adapted to `ScaledArray`: `all`, `flatten`, `leaves`, `map`, `structure`, `unflatten`, in `jax_scalify.tree`. Additionally, implementing `astype` as quite useful method on PyTrees!
- Loading branch information
Showing
4 changed files
with
179 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
from .tree_util import all, astype, flatten, leaves, map, structure, unflatten # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
from typing import Any, Callable | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import tree_util | ||
|
||
from jax_scalify.core import DTypeLike, is_scaled_leaf | ||
|
||
Leaf = Any | ||
|
||
|
||
def astype(tree: Any, dtype: DTypeLike, floating_only: bool = False) -> Any: | ||
"""Map `astype` method to all pytree leaves, `Array` or `ScaledArray`. | ||
Args: | ||
tree: the pytree to cast. | ||
dtype: Dtype to cast to. | ||
floating_only: Only convert leaves with floating datatype. | ||
Returns: | ||
A new PyTree with the same structure, with casting to new dtype. | ||
""" | ||
if floating_only: | ||
# Convert only leaves with floating dtype. | ||
cast_fn = lambda v: v.astype(dtype) if jnp.issubdtype(v.dtype, jnp.floating) else v | ||
return tree_util.tree_map(cast_fn, tree, is_leaf=is_scaled_leaf) | ||
return tree_util.tree_map(lambda v: v.astype(dtype), tree, is_leaf=is_scaled_leaf) | ||
|
||
|
||
def all(tree: Any) -> bool: | ||
"""Call all() over the leaves of a tree, `Array` or `ScaledArray` | ||
Args: | ||
tree: the pytree to evaluate | ||
Returns: | ||
result: boolean True or False | ||
""" | ||
return all(jax.tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf)) | ||
|
||
|
||
def flatten(tree: Any) -> tuple[list[Leaf], tree_util.PyTreeDef]: | ||
"""Flattens a pytree, with `Array` or `ScaledArray` leaves. | ||
The flattening order (i.e. the order of elements in the output list) | ||
is deterministic, corresponding to a left-to-right depth-first tree | ||
traversal. | ||
Args: | ||
tree: a pytree to flatten. | ||
Returns: | ||
A pair where the first element is a list of leaf values and the second | ||
element is a treedef representing the structure of the flattened tree. | ||
See Also: | ||
- :func:`jax_scalify.tree.leaves` | ||
- :func:`jax_scalify.tree.structure` | ||
- :func:`jax_scalify.tree.unflatten` | ||
""" | ||
return tree_util.tree_flatten(tree, is_leaf=is_scaled_leaf) | ||
|
||
|
||
def leaves( | ||
tree: Any, | ||
) -> list[Leaf]: | ||
"""Gets the leaves (`Array` or `ScaledArray`) of a pytree. | ||
Args: | ||
tree: the pytree for which to get the leaves | ||
Returns: | ||
leaves: a list of tree leaves. | ||
See Also: | ||
- :func:`jax_scalify.tree.flatten` | ||
- :func:`jax_scalify.tree.structure` | ||
- :func:`jax_scalify.tree.unflatten` | ||
""" | ||
return tree_util.tree_leaves(tree, is_leaf=is_scaled_leaf) | ||
|
||
|
||
def map(f: Callable[..., Any], tree: Any, *rest: Any) -> Any: | ||
"""Maps a multi-input function over pytree args to produce a new pytree. | ||
Args: | ||
f: function that takes ``1 + len(rest)`` arguments, to be applied at the | ||
corresponding leaves of the pytrees. | ||
tree: a pytree to be mapped over, with each leaf providing the first | ||
positional argument to ``f``. | ||
rest: a tuple of pytrees, each of which has the same structure as ``tree`` | ||
or has ``tree`` as a prefix. | ||
Returns: | ||
A new pytree with the same structure as ``tree`` but with the value at each | ||
leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding | ||
leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in | ||
``rest``. | ||
See Also: | ||
- :func:`jax_scalify.tree.leaves` | ||
- :func:`jax_scalify.tree.reduce` | ||
""" | ||
return tree_util.tree_map(f, tree, *rest, is_leaf=is_scaled_leaf) | ||
|
||
|
||
def structure(tree: Any) -> tree_util.PyTreeDef: | ||
"""Gets the treedef for a pytree, with `Array` or `ScaledArray` leaves. | ||
Args: | ||
tree: the pytree for which to get the leaves | ||
Returns: | ||
pytreedef: a PyTreeDef representing the structure of the tree. | ||
See Also: | ||
- :func:`jax_scalify.tree.flatten` | ||
- :func:`jax_scalify.tree.leaves` | ||
- :func:`jax_scalify.tree.unflatten` | ||
""" | ||
return tree_util.tree_structure(tree, is_leaf=is_scaled_leaf) | ||
|
||
|
||
# Alias of JAX tree unflatten. | ||
unflatten = jax.tree_util.tree_unflatten |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# Copyright (c) 2024 Graphcore Ltd. All rights reserved. | ||
import chex | ||
import numpy as np | ||
|
||
import jax_scalify as jsa | ||
|
||
|
||
class ScalifyTreeUtilTests(chex.TestCase): | ||
def test__tree_flatten__proper_result(self): | ||
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} | ||
outputs, _ = jsa.tree.flatten(values) | ||
assert len(outputs) == 2 | ||
assert outputs[0] == 2 | ||
assert isinstance(outputs[1], jsa.ScaledArray) | ||
assert np.asarray(outputs[1]) == 1.5 | ||
|
||
def test__tree_leaves__proper_result(self): | ||
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} | ||
outputs = jsa.tree.leaves(values) | ||
assert len(outputs) == 2 | ||
assert outputs[0] == 2 | ||
assert isinstance(outputs[1], jsa.ScaledArray) | ||
assert np.asarray(outputs[1]) == 1.5 | ||
|
||
def test__tree_structure__proper_result(self): | ||
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} | ||
pytree = jsa.tree.structure(values) | ||
assert pytree == jsa.tree.flatten(values)[1] | ||
|
||
def test__tree_unflatten__proper_result(self): | ||
values_in = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} | ||
outputs, pytree = jsa.tree.flatten(values_in) | ||
values_out = jsa.tree.unflatten(pytree, outputs) | ||
assert values_out == values_in | ||
|
||
def test__tree_map__proper_result(self): | ||
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} | ||
outputs = jsa.tree.map(lambda v: v.dtype, values) | ||
assert outputs == {"a": np.int32, "b": np.float32} | ||
|
||
def test__tree_astype__all_leaves_casting(self): | ||
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} | ||
outputs = jsa.tree.astype(values, dtype=np.float16) | ||
dtypes = jsa.tree.map(lambda v: v.dtype, outputs) | ||
assert dtypes == {"a": np.float16, "b": np.float16} | ||
|
||
def test__tree_astype__only_float_casting(self): | ||
values = {"a": np.int32(2), "b": jsa.as_scaled_array(np.float32(1.5), 1.0)} | ||
outputs = jsa.tree.astype(values, dtype=np.float16, floating_only=True) | ||
dtypes = jsa.tree.map(lambda v: v.dtype, outputs) | ||
assert dtypes == {"a": np.int32, "b": np.float16} |