Skip to content

Commit

Permalink
Merge pull request #4191 from google:nnx-optimize-jit
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 674457375
  • Loading branch information
Flax Authors committed Sep 13, 2024
2 parents b967964 + 2079f15 commit d111adf
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 13 deletions.
18 changes: 10 additions & 8 deletions flax/nnx/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from flax import struct
from flax.nnx.object import Object
from flax.typing import MISSING, PathParts
from flax.typing import Missing, PathParts
from flax.nnx import graph


Expand Down Expand Up @@ -59,7 +59,7 @@ def extract_graph_nodes(
pytree: A,
/,
*,
prefix: tp.Any = MISSING,
prefix: tp.Any = Missing,
validate_fn: tp.Callable[[KeyPath, Prefix, Leaf], None] | None = None,
) -> (
tuple[A, tuple[tp.Any, ...]]
Expand Down Expand Up @@ -101,7 +101,7 @@ def extract_graph_nodes(

pytree_out = jax.tree.unflatten(treedef, leaves)

if prefix is MISSING:
if prefix is Missing:
return pytree_out, tuple(nodes) # type: ignore[bad-return-type]
else:
return pytree_out, tuple(nodes), tuple(node_prefixes) # type: ignore[bad-return-type]
Expand Down Expand Up @@ -330,12 +330,13 @@ def to_tree(
tree,
/,
*,
prefix: tp.Any = MISSING,
prefix: tp.Any = Missing,
split_fn: tp.Callable[
[graph.SplitContext, KeyPath, Prefix, Leaf], tp.Any
] = default_split_fn,
map_non_graph_nodes: bool = False,
ctxtag: str | None = None,
check_aliasing: bool = True,
) -> tp.Any:
leaf_prefixes = broadcast_prefix(
prefix,
Expand All @@ -351,9 +352,10 @@ def to_tree(
with graph.split_context(ctxtag) as split_ctx:
for (keypath, leaf), leaf_prefix in zip(leaf_keys, leaf_prefixes):
if graph.is_graph_node(leaf):
check_consistent_aliasing(
leaf, leaf_prefix, node_prefixes=node_prefixes
)
if check_aliasing:
check_consistent_aliasing(
leaf, leaf_prefix, node_prefixes=node_prefixes
)
tree_node = split_fn(split_ctx, keypath, leaf_prefix, leaf)
leaves_out.append(tree_node)
else:
Expand Down Expand Up @@ -381,7 +383,7 @@ def from_tree(
tree: tp.Any,
/,
*,
prefix: tp.Any = MISSING,
prefix: tp.Any = Missing,
merge_fn: tp.Callable[
[graph.MergeContext, KeyPath, Prefix, Leaf], tp.Any
] = merge_tree_node,
Expand Down
1 change: 1 addition & 0 deletions flax/nnx/transforms/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def jit_wrapper(*args, **kwargs):
(args, kwargs),
prefix=(in_shardings, kwarg_shardings),
split_fn=_jit_split_fn,
check_aliasing=in_shardings is not None,
ctxtag='jit',
)
pure_args_out, pure_kwargs_out, pure_out = jitted_fn(
Expand Down
11 changes: 6 additions & 5 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit d111adf

Please sign in to comment.