From 492a72a55d66a2682e6c52dc12baabbf87cf99ae Mon Sep 17 00:00:00 2001 From: IvyZX Date: Wed, 30 Oct 2024 17:07:12 -0700 Subject: [PATCH] Add logical axis global context support for NNX --- flax/core/spmd.py | 59 ++++++++++++++++++++++++++++++++++++++++ flax/linen/__init__.py | 8 ++++-- flax/linen/spmd.py | 45 ++++-------------------------- flax/nnx/nn/recurrent.py | 2 +- flax/nnx/spmd.py | 13 +++++---- tests/nnx/spmd_test.py | 33 ++++++++++++++++++++++ 6 files changed, 110 insertions(+), 50 deletions(-) create mode 100644 flax/core/spmd.py diff --git a/flax/core/spmd.py b/flax/core/spmd.py new file mode 100644 index 0000000000..46142f5559 --- /dev/null +++ b/flax/core/spmd.py @@ -0,0 +1,59 @@ +import contextlib +import dataclasses +import threading + +from flax.typing import ( + LogicalRules, + Sharding, +) + +# Dynamic Axis Mapping Context +# ------------------------------------------------------------------------------ + + +@dataclasses.dataclass +class _AxisRules(threading.local): + """Dynamic logical axis to mesh axis binding context.""" + + rules: LogicalRules = () + + +# Global axis binding context. +_axis_rules = _AxisRules() + + +def set_logical_axis_rules(rules: LogicalRules): + """Sets the global logical axis to mesh axis binding.""" + _axis_rules.rules = rules + + +def get_logical_axis_rules() -> LogicalRules: + """Returns the global logical axis to mesh axis binding.""" + return _axis_rules.rules + + +@contextlib.contextmanager +def logical_axis_rules(rules: LogicalRules): + """Context manager for setting the logical to mesh axis bindings.""" + old_rules = _axis_rules.rules + try: + _axis_rules.rules = rules + yield + finally: + _axis_rules.rules = old_rules + + +def composite_rules(rule1, rule2): + if not rule1 and not rule2: return () + rules = {alias: value for alias, value in rule1} + for alias, value in rule2: + if alias in rules and rules[alias] != value: + raise ValueError(f'Inconsistent logical axis annotations for {alias}: ' + f'{rules[alias]} vs {value}') + rules[alias] = value + return tuple(rules.items()) + +def from_sharding_rules(sharding: Sharding, + sharding_rules: LogicalRules) -> Sharding: + rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} + return (rules[s] if s in rules else None for s in sharding) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index ff4e384acd..6bb715667f 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -31,6 +31,11 @@ unbox as unbox, with_partitioning as with_partitioning, ) +from flax.core.spmd import ( + get_logical_axis_rules as get_logical_axis_rules, + logical_axis_rules as logical_axis_rules, + set_logical_axis_rules as set_logical_axis_rules, +) from .activation import ( PReLU as PReLU, celu as celu, @@ -130,12 +135,9 @@ ) from .spmd import ( LogicallyPartitioned as LogicallyPartitioned, - get_logical_axis_rules as get_logical_axis_rules, - logical_axis_rules as logical_axis_rules, logical_to_mesh, logical_to_mesh_axes, logical_to_mesh_sharding, - set_logical_axis_rules as set_logical_axis_rules, with_logical_constraint, with_logical_partitioning as with_logical_partitioning, ) diff --git a/flax/linen/spmd.py b/flax/linen/spmd.py index cd622bbdae..0c1945a097 100644 --- a/flax/linen/spmd.py +++ b/flax/linen/spmd.py @@ -25,11 +25,9 @@ """ import collections -import contextlib import dataclasses import enum import functools -import threading from typing import Any from collections.abc import Callable, Sequence @@ -39,6 +37,9 @@ from flax import struct from flax.core import meta +from flax.core.spmd import ( + get_logical_axis_rules, +) from flax.typing import ( Array, LogicalNames, @@ -49,42 +50,6 @@ ) -# Dynamic Axis Mapping Context -# ------------------------------------------------------------------------------ - - -@dataclasses.dataclass -class _AxisRules(threading.local): - """Dynamic logical axis to mesh axis binding context.""" - - rules: LogicalRules = () - - -# Global axis binding context. -_axis_rules = _AxisRules() - - -def set_logical_axis_rules(rules: LogicalRules): - """Sets the global logical axis to mesh axis binding.""" - _axis_rules.rules = rules - - -def get_logical_axis_rules() -> LogicalRules: - """Returns the global logical axis to mesh axis binding.""" - return _axis_rules.rules - - -@contextlib.contextmanager -def logical_axis_rules(rules: LogicalRules): - """Context manager for setting the logical to mesh axis bindings.""" - old_rules = _axis_rules.rules - try: - _axis_rules.rules = rules - yield - finally: - _axis_rules.rules = old_rules - - class _UnassignedAxis: """Sentinel class for unassigned logical axis name.""" @@ -115,7 +80,7 @@ def _logical_to_mesh_axes( if array_dim_names is None: return None if rules is None: - rules = _axis_rules.rules + rules = get_logical_axis_rules() axis_name_counts = collections.Counter(array_dim_names) dups = tuple( k for k, v in axis_name_counts.items() if v > 1 and k is not None @@ -292,7 +257,7 @@ def with_logical_constraint( """Version of jit's with_sharding_constraint that uses logical axis names.""" # If no axis binding is set, this is a no-op. if rules is None: - rules = _axis_rules.rules + rules = get_logical_axis_rules() if not rules or logical_axis_resources is None: return x # Translate logical names to mesh assignments. diff --git a/flax/nnx/nn/recurrent.py b/flax/nnx/nn/recurrent.py index 44b89ad979..8f750392b0 100644 --- a/flax/nnx/nn/recurrent.py +++ b/flax/nnx/nn/recurrent.py @@ -590,7 +590,7 @@ def num_feature_axes(self) -> int: class RNN(Module): """The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence - using :func:`flax.linen.scan`. + using :func:`flax.nnx.scan`. """ def __init__( diff --git a/flax/nnx/spmd.py b/flax/nnx/spmd.py index fd9deb89f8..2f37921915 100644 --- a/flax/nnx/spmd.py +++ b/flax/nnx/spmd.py @@ -19,6 +19,7 @@ from jax.interpreters import pxla from jax.sharding import PartitionSpec +import flax.core.spmd as core_spmd from flax.nnx import variablelib from flax.typing import ( Array, @@ -89,15 +90,15 @@ def _maybe_replicate(x): else: return None - def from_rules(sharding, sharding_rules): - rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules} - return (rules[s] if s in rules else None for s in sharding) - def f(x): if isinstance(x, (variablelib.VariableState, variablelib.Variable)): if hasattr(x, 'sharding') and x.sharding: - if hasattr(x, 'sharding_rules') and x.sharding_rules: - return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules))) + if core_spmd.get_logical_axis_rules() or hasattr(x, 'sharding_rules'): + context_rules = core_spmd.get_logical_axis_rules() + local_rules = getattr(x, 'sharding_rules', ()) + rules = core_spmd.composite_rules(context_rules, local_rules) + return x.replace(PartitionSpec( + *core_spmd.from_sharding_rules(x.sharding, rules))) return x.replace(PartitionSpec(*x.sharding)) else: return x.replace(_maybe_replicate(x.value)) diff --git a/tests/nnx/spmd_test.py b/tests/nnx/spmd_test.py index 6a202e8135..046798e4e8 100644 --- a/tests/nnx/spmd_test.py +++ b/tests/nnx/spmd_test.py @@ -19,6 +19,7 @@ from jax.experimental import mesh_utils from jax.sharding import Mesh, PartitionSpec +import flax from flax import nnx @@ -158,6 +159,38 @@ def __call__(self, x: jax.Array): self.assertEqual(badds, [(0, 'layers'), (0, 'layers')]) self.assertEqual(bremoves, [(0, 'layers')]) + def test_logical_rules(self): + class Foo(nnx.Module): + def __init__(self): + self.w = nnx.Param( + nnx.with_partitioning( + lambda: jnp.ones((8, 2)), + sharding=('row-alias', 'col-alias'), + sharding_rules=(('row-alias', 'row'),) + )() + ) + self.b = nnx.Param( + nnx.with_partitioning( + lambda: jnp.zeros((2, )), + sharding=('col-alias',) + )() + ) + + def __call__(self, x): + return x @ self.w + self.b + + graphdef, params = nnx.split(Foo()) + state = nnx.TrainState.create( + graphdef, + params=params, + tx=optax.adam(1e-3), + ) + with flax.core.spmd.logical_axis_rules((('col-alias', 'col'),)): + state_spec = nnx.get_partition_spec(state) + + assert state_spec.params['w'].value == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col') if __name__ == '__main__': absltest.main()