From 2079f15b1688bed34593828a8a63b62306af5391 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 13 Sep 2024 09:07:17 +0000 Subject: [PATCH] split docs --- flax/nnx/extract.py | 18 ++++++++++-------- flax/nnx/transforms/compilation.py | 1 + uv.lock | 11 ++++++----- 3 files changed, 17 insertions(+), 13 deletions(-) diff --git a/flax/nnx/extract.py b/flax/nnx/extract.py index 6ecf6f2405..845544c307 100644 --- a/flax/nnx/extract.py +++ b/flax/nnx/extract.py @@ -22,7 +22,7 @@ from flax import struct from flax.nnx.object import Object -from flax.typing import MISSING, PathParts +from flax.typing import Missing, PathParts from flax.nnx import graph @@ -59,7 +59,7 @@ def extract_graph_nodes( pytree: A, /, *, - prefix: tp.Any = MISSING, + prefix: tp.Any = Missing, validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None, ) -> ( tuple[A, tuple[tp.Any, ...]] @@ -101,7 +101,7 @@ def extract_graph_nodes( pytree_out = jax.tree.unflatten(treedef, leaves) - if prefix is MISSING: + if prefix is Missing: return pytree_out, tuple(nodes) # type: ignore[bad-return-type] else: return pytree_out, tuple(nodes), tuple(node_prefixes) # type: ignore[bad-return-type] @@ -330,12 +330,13 @@ def to_tree( tree, /, *, - prefix: tp.Any = MISSING, + prefix: tp.Any = Missing, split_fn: tp.Callable[ [graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any ] = default_split_fn, map_non_graph_nodes: bool = False, ctxtag: str | None = None, + check_aliasing: bool = True, ) -> tp.Any: leaf_prefixes = broadcast_prefix( prefix, @@ -351,9 +352,10 @@ def to_tree( with graph.split_context(ctxtag) as split_ctx: for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes): if graph.is_graph_node(leaf): - check_consistent_aliasing( - leaf, leaf_prefix, node_prefixes=node_prefixes - ) + if check_aliasing: + check_consistent_aliasing( + leaf, leaf_prefix, node_prefixes=node_prefixes + ) tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf) leaves_out.append(tree_node) else: @@ -381,7 +383,7 @@ def from_tree( tree: tp.Any, /, *, - prefix: tp.Any = MISSING, + prefix: tp.Any = Missing, merge_fn: tp.Callable[ [graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any ] = merge_tree_node, diff --git a/flax/nnx/transforms/compilation.py b/flax/nnx/transforms/compilation.py index d715898ce0..1f63654d63 100644 --- a/flax/nnx/transforms/compilation.py +++ b/flax/nnx/transforms/compilation.py @@ -324,6 +324,7 @@ def jit_wrapper(*args, **kwargs): (args, kwargs), prefix=(in_shardings, kwarg_shardings), split_fn=_jit_split_fn, + check_aliasing=in_shardings is not None, ctxtag='jit', ) pure_args_out, pure_kwargs_out, pure_out = jitted_fn( diff --git a/uv.lock b/uv.lock index 5dbc9e8070..29d358e255 100644 --- a/uv.lock +++ b/uv.lock @@ -767,7 +767,7 @@ wheels = [ [[package]] name = "flax" -version = "0.8.6" +version = "0.9.0" source = { editable = "." } dependencies = [ { name = "jax" }, @@ -809,7 +809,9 @@ docs = [ testing = [ { name = "clu" }, { name = "einops" }, - { name = "gymnasium", extra = ["accept-rom-license", "atari"] }, + { name = "gymnasium" }, + { name = "gymnasium", extra = ["accept-rom-license"] }, + { name = "gymnasium", extra = ["atari"] }, { name = "jaxlib" }, { name = "jaxtyping" }, { name = "jraph" }, @@ -1044,9 +1046,11 @@ wheels = [ [package.optional-dependencies] accept-rom-license = [ + { name = "autorom" }, { name = "autorom", extra = ["accept-rom-license"] }, ] atari = [ + { name = "shimmy" }, { name = "shimmy", extra = ["atari"] }, ] @@ -3587,9 +3591,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 }, { url = "https://files.pythonhosted.org/packages/33/3e/a2f59384587eff6aeb7d37b6780de7fedd2214935e27520430ca9f5b7975/triton-3.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5ce8520437c602fb633f1324cc3871c47bee3b67acf9756c1a66309b60e3216c", size = 209438883 }, { url = "https://files.pythonhosted.org/packages/fe/7b/7757205dee3628f75e7991021d15cd1bd0c9b044ca9affe99b50879fc0e1/triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:34e509deb77f1c067d8640725ef00c5cbfcb2052a1a3cb6a6d343841f92624eb", size = 209464695 }, - { url = "https://files.pythonhosted.org/packages/15/67/84e5a4b7b45bdeb11da26a67dfa2b988c512abbcbcad8cbc30aa579051b2/triton-3.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39b052da883351fdf6be3d93cedae6db3b8e3988d3b09ed221bccecfa9612230", size = 209380247 }, - { url = "https://files.pythonhosted.org/packages/ea/6b/1d72cc8a7379822dadf050474add7d8b73b02c35057446b6f17d27cb9ea2/triton-3.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cd34f19a8582af96e6291d4afce25dac08cb2a5d218c599163761e8e0827208e", size = 209442823 }, - { url = "https://files.pythonhosted.org/packages/ae/b2/048c9ecfdba0e6b0ae3c02eed2d9dd3e9e990a6d46da98555cf0c2232168/triton-3.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0d5e10de8c011adeb7c878c6ce0dd6073b14367749e34467f1cff2bde1b78253", size = 209468633 }, ] [[package]]