Commit

Merge pull request #4194 from IvyZX:bdg-tree
PiperOrigin-RevId: 675627999
Flax Authors committed Sep 17, 2024
2 parents 9eb0a61 + 4a04f75 commit ddaef57
Showing 7 changed files with 477 additions and 439 deletions.
388 changes: 172 additions & 216 deletions docs_nnx/guides/bridge_guide.ipynb

Large diffs are not rendered by default.

264 changes: 93 additions & 171 deletions docs_nnx/guides/bridge_guide.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions flax/core/meta.py
@@ -337,7 +337,7 @@ def get_partition_spec(tree: Any) -> Any:
   """Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values."""
 
   def f(x):
-    if isinstance(x, Partitioned):
+    if hasattr(x, 'get_partition_spec'):
       return x.get_partition_spec()
     # Unboxed arrays, which should be replicated across all devices
     elif hasattr(x, 'shape'):
@@ -346,7 +346,7 @@ def f(x):
       return None
 
   return jax.tree_util.tree_map(
-    f, tree, is_leaf=lambda x: isinstance(x, Partitioned)
+    f, tree, is_leaf=lambda x: isinstance(x, AxisMetadata)
   )
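
In effect, ``get_partition_spec`` now honors any ``AxisMetadata`` box that exposes a ``get_partition_spec`` method (such as the ``NNXMeta`` boxes extended below), while unboxed arrays are still treated as fully replicated. A minimal sketch of that behavior, assuming only a stock Flax/JAX install:

import jax.numpy as jnp
from flax.core import meta

# A Partitioned box carries axis names; a plain array carries none.
tree = {
    'kernel': meta.Partitioned(jnp.zeros((8, 4)), names=('data', 'model')),
    'bias': jnp.zeros((4,)),
}
print(meta.get_partition_spec(tree))
# {'kernel': PartitionSpec('data', 'model'), 'bias': PartitionSpec()}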


74 changes: 72 additions & 2 deletions flax/nnx/bridge/variables.py
@@ -12,12 +12,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from collections import defaultdict
 from typing import Any, TypeVar
 
 import jax
 from flax import struct
 from flax.core import meta
+from flax.nnx import spmd
+from flax.nnx import traversals
 from flax.nnx import variables as variableslib
+from flax.nnx.module import GraphDef
 import typing as tp


@@ -105,6 +109,28 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]':
     # TODO: implement this, supporting hooks
     return self
 
+  def get_partition_spec(self) -> jax.sharding.PartitionSpec:
+    """Returns the ``PartitionSpec`` for this partitioned value."""
+    nnx_var = self.to_nnx_variable().to_state()
+    return spmd.get_partition_spec(nnx_var).value
+
+  def to_nnx_variable(self) -> variableslib.Variable:
+    return self.var_type(self.value, **self.metadata)
+
+
+def is_vanilla_variable(vs: variableslib.VariableState) -> bool:
+  """A variable state is vanilla if its metadata is essentially blank.
+  Returns False only if it has non-empty hooks or any non-built-in attribute.
+  """
+  for key, value in vs.get_metadata().items():
+    if key.endswith('_hooks'):
+      if value != ():
+        return False
+    else:
+      return False
+  return True
 
 
 def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
   metadata = vs.get_metadata()
@@ -113,6 +139,8 @@ def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata:
     if hasattr(linen_type, 'from_nnx_metadata'):
       return linen_type.from_nnx_metadata({'value': vs.value, **metadata})
     return linen_type(vs.value, **metadata)
+  if is_vanilla_variable(vs):
+    return vs.value
   return NNXMeta(vs.type, vs.value, metadata)
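
The ``is_vanilla_variable`` check added above means metadata-free NNX variables now surface on the Linen side as plain arrays, while variables carrying extra metadata (sharding annotations, hooks, etc.) still get boxed in ``NNXMeta``. A rough sketch of that distinction, assuming the standard ``nnx.Param`` constructor accepts arbitrary metadata keywords:

from flax import nnx
import jax.numpy as jnp

plain = nnx.Param(jnp.ones((4,)))                       # no extra metadata
boxed = nnx.Param(jnp.ones((4,)), sharding=('model',))  # custom metadata

# to_linen_var(plain.to_state())  -> the raw array (vanilla, so unboxed)
# to_linen_var(boxed.to_state())  -> an NNXMeta wrapping the value and metadata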


@@ -128,11 +156,53 @@ def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable:
   vtype = variable_type(col)
   if isinstance(x, NNXMeta):
     assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}'
-    return x.var_type(x.value, **x.metadata)
+    return x.to_nnx_variable()
   if isinstance(x, meta.AxisMetadata):
     x_metadata = vars(x)
     if hasattr(x, 'to_nnx_metadata'):
       x_metadata = x.to_nnx_metadata()
     assert hasattr(x, 'value')
     return vtype(**x_metadata, linen_meta_type=type(x))
   return vtype(x)


+def _recursive_merge(dict1, dict2):
+  """Recursively merge two dicts."""
+  flat_map = traversals.flatten_mapping(dict1)
+  flat_map |= traversals.flatten_mapping(dict2)
+  return traversals.unflatten_mapping(flat_map)
+
+
+def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]:
+  nnx_vars = jax.tree_util.tree_map_with_path(
+    lambda kp, x: to_nnx_var(get_col_name(kp), x),
+    variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
+  nnx_attrs: dict[str, Any] = defaultdict(dict)
+  for _, col_tree in nnx_vars.items():
+    assert isinstance(col_tree, dict)
+    for attr_name, value in col_tree.items():
+      assert isinstance(attr_name, str)
+      if isinstance(value, tp.Mapping):  # it's a sublayer
+        nnx_attrs[attr_name] = _recursive_merge(nnx_attrs[attr_name], value)
+      else:
+        nnx_attrs[attr_name] = value  # it's a variable on this layer
+  return nnx_attrs
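
``linen_vars_to_nnx_attrs`` regroups a per-collection Linen variable dict into a per-attribute dict, so all of a sublayer's variables end up under one attribute name regardless of collection. A hypothetical input/output pair (layer and leaf names are illustrative only):

import jax.numpy as jnp

# Linen layout: one top-level dict per collection.
linen_vars = {
    'params': {'dense': {'kernel': jnp.zeros((4, 2)), 'bias': jnp.zeros((2,))}},
    'batch_stats': {'dense': {'mean': jnp.zeros((2,))}},
}
# After linen_vars_to_nnx_attrs, the same leaves are regrouped per attribute,
# each wrapped in the nnx.Variable type inferred from its collection:
# {'dense': {'kernel': Param(...), 'bias': Param(...), 'mean': BatchStat(...)}}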


+def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict:
+  linen_structured = {}
+  for kp, v in traversals.flatten_mapping(
+      nnx_attrs,
+      is_leaf=lambda _, x: isinstance(x, variableslib.Variable | GraphDef),
+  ).items():
+    if isinstance(v, variableslib.Variable):
+      col_name = variable_type_name(type(v))
+    else:
+      col_name = 'nnx'  # it must be an nnx.GraphDef, for some ToLinen submodule
+    linen_structured[(col_name, *kp)] = v
+  variables = traversals.unflatten_mapping(linen_structured)
+  variables = jax.tree.map(lambda x: to_linen_var(x.to_state()),
+                           variables,
+                           is_leaf=lambda x: isinstance(x, variableslib.Variable))
+  return variables
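
``nnx_attrs_to_linen_vars`` is the inverse: each ``nnx.Variable`` leaf is filed back under the collection named by its variable type, ``GraphDef`` leaves go under an ``'nnx'`` collection, and vanilla variables are unboxed by ``to_linen_var`` on the way out. A short sketch, assuming the helpers remain importable from ``flax.nnx.bridge.variables``:

from flax import nnx
from flax.nnx.bridge import variables as bridge_variables
import jax.numpy as jnp

nnx_attrs = {
    'dense': {
        'kernel': nnx.Param(jnp.zeros((4, 2))),
        'mean': nnx.BatchStat(jnp.zeros((2,))),
    }
}
linen_vars = bridge_variables.nnx_attrs_to_linen_vars(nnx_attrs)
# {'params': {'dense': {'kernel': ...}}, 'batch_stats': {'dense': {'mean': ...}}}
# (raw arrays here, since neither variable carries extra metadata)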

39 changes: 19 additions & 20 deletions flax/nnx/bridge/wrappers.py
@@ -104,7 +104,7 @@ class ToNNX(Module):
   >>> model = nnx.bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x)
   >>> # Like Linen apply(), but using NNX's direct call method
   >>> y = model(x)
-  >>> nnx.state(model).params.kernel.value.shape
+  >>> model.kernel.shape
   (32, 64)
 
   Args:
@@ -121,7 +121,7 @@ def __init__(
   ):
     self.module = module
     self.rngs = rngs
-    self.linen_collections: tuple[str, ...] = ()
+    self.linen_attributes: tuple[str, ...] = ()
 
   def lazy_init(self, *args, **kwargs):
     """A shortcut of calling `nnx.bridge.lazy_init()` upon this module."""
@@ -146,20 +146,17 @@ def __call__(
         _rngs['params'] = _rngs.pop('default')
       out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs)
 
-      nnx_vars = jtu.tree_map_with_path(
-        lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
-        variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
-      linen_collections = set()
-      for col, tree in nnx_vars.items():
-        setattr(self, col, tree)
-        linen_collections.add(col)
-      self.linen_collections = tuple(linen_collections)  # make it hashable
+      nnx_attrs = bv.linen_vars_to_nnx_attrs(variables)
+      linen_attributes = set()
+      for attr_name, value in nnx_attrs.items():
+        setattr(self, attr_name, value)
+        linen_attributes.add(attr_name)
+      self.linen_attributes = tuple(linen_attributes)  # make it hashable
 
     else:
-      variables = {col: jax.tree.map(lambda x: bv.to_linen_var(x.to_state()),
-                                     getattr(self, col),
-                                     is_leaf=lambda x: isinstance(x, nnx.Variable))
-                   for col in self.linen_collections}
+      nnx_attrs = {name: getattr(self, name) for name in self.linen_attributes}
+      variables = bv.nnx_attrs_to_linen_vars(nnx_attrs)
 
       _rngs = (
         {name: stream() for name, stream in rngs.items()} if rngs else {}
       )
@@ -168,11 +165,13 @@ def __call__(
     # Split out the updates if `mutable` is passed into the Flax module
     if kwargs.get('mutable', False) != False:
       out, updates = out
-      updates = jtu.tree_map_with_path(
-        lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x),
-        updates, is_leaf=lambda x: isinstance(x, meta.AxisMetadata))
-      for collection, value in updates.items():
-        setattr(self, collection, value)
+      nnx_attrs = bv.linen_vars_to_nnx_attrs(updates)
+      for attr_name, value in nnx_attrs.items():
+        if hasattr(self, attr_name) and isinstance(value, dict):
+          original_tree = getattr(self, attr_name)
+          setattr(self, attr_name, original_tree | value)
+        else:
+          setattr(self, attr_name, value)
 
     return out
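
Because lazy initialization and the read-back path now both go through ``linen_vars_to_nnx_attrs`` / ``nnx_attrs_to_linen_vars``, a wrapped Linen module exposes its state as per-attribute NNX variables rather than per-collection trees, and mutable-collection updates are merged back attribute by attribute. A small usage sketch (layer choice and shapes are illustrative), mirroring the updated docstring above:

import jax.numpy as jnp
import flax.linen as nn
from flax import nnx

x = jnp.ones((1, 32))
model = nnx.bridge.ToNNX(nn.Dense(64), rngs=nnx.Rngs(0)).lazy_init(x)
y = model(x)

# Variables are now reachable as attributes, not via a `params` collection:
print(model.kernel.shape)  # (32, 64)
print(model.bias.shape)    # (64,)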

@@ -202,7 +201,7 @@ class ToLinen(linen.Module):
   >>> y, variables = model.init_with_output(jax.random.key(0), x)
   >>> y.shape
   (1, 64)
-  >>> variables['params']['kernel'].value.shape
+  >>> variables['params']['kernel'].shape
   (32, 64)
   >>> # The static GraphDef of the underlying NNX module
   >>> variables.keys()
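
On the ``ToLinen`` side, the companion effect (visible in the doctest change above, where ``.value`` is dropped) is that metadata-free parameters come back as plain arrays in the ``variables`` dict; only variables carrying extra metadata keep an ``NNXMeta`` box. A short sketch, assuming the ``nnx.bridge.ToLinen`` calling convention from the existing docstring:

import jax
import jax.numpy as jnp
from flax import nnx

x = jnp.ones((1, 32))
model = nnx.bridge.ToLinen(nnx.Linear, args=(32, 64))
y, variables = model.init_with_output(jax.random.key(0), x)

kernel = variables['params']['kernel']
print(kernel.shape)  # (32, 64) -- a plain jax.Array, no .value unwrapping needed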
Diffs for the remaining 2 changed files are not shown.
