Skip to content

Commit

Permalink
Add logical axis global context support for NNX
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX committed Oct 31, 2024
1 parent 917c097 commit 492a72a
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 50 deletions.
59 changes: 59 additions & 0 deletions flax/core/spmd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
import contextlib
import dataclasses
import threading

from flax.typing import (
LogicalRules,
Sharding,
)

# Dynamic Axis Mapping Context
# ------------------------------------------------------------------------------


@dataclasses.dataclass
class _AxisRules(threading.local):
"""Dynamic logical axis to mesh axis binding context."""

rules: LogicalRules = ()


# Global axis binding context.
_axis_rules = _AxisRules()


def set_logical_axis_rules(rules: LogicalRules):
"""Sets the global logical axis to mesh axis binding."""
_axis_rules.rules = rules


def get_logical_axis_rules() -> LogicalRules:
"""Returns the global logical axis to mesh axis binding."""
return _axis_rules.rules


@contextlib.contextmanager
def logical_axis_rules(rules: LogicalRules):
"""Context manager for setting the logical to mesh axis bindings."""
old_rules = _axis_rules.rules
try:
_axis_rules.rules = rules
yield
finally:
_axis_rules.rules = old_rules


def composite_rules(rule1, rule2):
if not rule1 and not rule2: return ()
rules = {alias: value for alias, value in rule1}
for alias, value in rule2:
if alias in rules and rules[alias] != value:
raise ValueError(f'Inconsistent logical axis annotations for {alias}: '
f'{rules[alias]} vs {value}')
rules[alias] = value
return tuple(rules.items())

def from_sharding_rules(sharding: Sharding,
sharding_rules: LogicalRules) -> Sharding:
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
return (rules[s] if s in rules else None for s in sharding)
8 changes: 5 additions & 3 deletions flax/linen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
unbox as unbox,
with_partitioning as with_partitioning,
)
from flax.core.spmd import (
get_logical_axis_rules as get_logical_axis_rules,
logical_axis_rules as logical_axis_rules,
set_logical_axis_rules as set_logical_axis_rules,
)
from .activation import (
PReLU as PReLU,
celu as celu,
Expand Down Expand Up @@ -130,12 +135,9 @@
)
from .spmd import (
LogicallyPartitioned as LogicallyPartitioned,
get_logical_axis_rules as get_logical_axis_rules,
logical_axis_rules as logical_axis_rules,
logical_to_mesh,
logical_to_mesh_axes,
logical_to_mesh_sharding,
set_logical_axis_rules as set_logical_axis_rules,
with_logical_constraint,
with_logical_partitioning as with_logical_partitioning,
)
Expand Down
45 changes: 5 additions & 40 deletions flax/linen/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,9 @@
"""

import collections
import contextlib
import dataclasses
import enum
import functools
import threading
from typing import Any
from collections.abc import Callable, Sequence

Expand All @@ -39,6 +37,9 @@

from flax import struct
from flax.core import meta
from flax.core.spmd import (
get_logical_axis_rules,
)
from flax.typing import (
Array,
LogicalNames,
Expand All @@ -49,42 +50,6 @@
)


# Dynamic Axis Mapping Context
# ------------------------------------------------------------------------------


@dataclasses.dataclass
class _AxisRules(threading.local):
"""Dynamic logical axis to mesh axis binding context."""

rules: LogicalRules = ()


# Global axis binding context.
_axis_rules = _AxisRules()


def set_logical_axis_rules(rules: LogicalRules):
"""Sets the global logical axis to mesh axis binding."""
_axis_rules.rules = rules


def get_logical_axis_rules() -> LogicalRules:
"""Returns the global logical axis to mesh axis binding."""
return _axis_rules.rules


@contextlib.contextmanager
def logical_axis_rules(rules: LogicalRules):
"""Context manager for setting the logical to mesh axis bindings."""
old_rules = _axis_rules.rules
try:
_axis_rules.rules = rules
yield
finally:
_axis_rules.rules = old_rules


class _UnassignedAxis:
"""Sentinel class for unassigned logical axis name."""

Expand Down Expand Up @@ -115,7 +80,7 @@ def _logical_to_mesh_axes(
if array_dim_names is None:
return None
if rules is None:
rules = _axis_rules.rules
rules = get_logical_axis_rules()
axis_name_counts = collections.Counter(array_dim_names)
dups = tuple(
k for k, v in axis_name_counts.items() if v > 1 and k is not None
Expand Down Expand Up @@ -292,7 +257,7 @@ def with_logical_constraint(
"""Version of jit's with_sharding_constraint that uses logical axis names."""
# If no axis binding is set, this is a no-op.
if rules is None:
rules = _axis_rules.rules
rules = get_logical_axis_rules()
if not rules or logical_axis_resources is None:
return x
# Translate logical names to mesh assignments.
Expand Down
2 changes: 1 addition & 1 deletion flax/nnx/nn/recurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,7 @@ def num_feature_axes(self) -> int:
class RNN(Module):
"""The ``RNN`` module takes any :class:`RNNCellBase` instance and applies it over a sequence

using :func:`flax.linen.scan`.
using :func:`flax.nnx.scan`.
"""

def __init__(
Expand Down
13 changes: 7 additions & 6 deletions flax/nnx/spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from jax.interpreters import pxla
from jax.sharding import PartitionSpec

import flax.core.spmd as core_spmd
from flax.nnx import variablelib
from flax.typing import (
Array,
Expand Down Expand Up @@ -89,15 +90,15 @@ def _maybe_replicate(x):
else:
return None

def from_rules(sharding, sharding_rules):
rules = {alias: on_mesh for (alias, on_mesh) in sharding_rules}
return (rules[s] if s in rules else None for s in sharding)

def f(x):
if isinstance(x, (variablelib.VariableState, variablelib.Variable)):
if hasattr(x, 'sharding') and x.sharding:
if hasattr(x, 'sharding_rules') and x.sharding_rules:
return x.replace(PartitionSpec(*from_rules(x.sharding, x.sharding_rules)))
if core_spmd.get_logical_axis_rules() or hasattr(x, 'sharding_rules'):
context_rules = core_spmd.get_logical_axis_rules()
local_rules = getattr(x, 'sharding_rules', ())
rules = core_spmd.composite_rules(context_rules, local_rules)
return x.replace(PartitionSpec(
*core_spmd.from_sharding_rules(x.sharding, rules)))
return x.replace(PartitionSpec(*x.sharding))
else:
return x.replace(_maybe_replicate(x.value))
Expand Down
33 changes: 33 additions & 0 deletions tests/nnx/spmd_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from jax.experimental import mesh_utils
from jax.sharding import Mesh, PartitionSpec

import flax
from flax import nnx


Expand Down Expand Up @@ -158,6 +159,38 @@ def __call__(self, x: jax.Array):
self.assertEqual(badds, [(0, 'layers'), (0, 'layers')])
self.assertEqual(bremoves, [(0, 'layers')])

def test_logical_rules(self):
class Foo(nnx.Module):
def __init__(self):
self.w = nnx.Param(
nnx.with_partitioning(
lambda: jnp.ones((8, 2)),
sharding=('row-alias', 'col-alias'),
sharding_rules=(('row-alias', 'row'),)
)()
)
self.b = nnx.Param(
nnx.with_partitioning(
lambda: jnp.zeros((2, )),
sharding=('col-alias',)
)()
)

def __call__(self, x):
return x @ self.w + self.b

graphdef, params = nnx.split(Foo())
state = nnx.TrainState.create(
graphdef,
params=params,
tx=optax.adam(1e-3),
)
with flax.core.spmd.logical_axis_rules((('col-alias', 'col'),)):
state_spec = nnx.get_partition_spec(state)

assert state_spec.params['w'].value == PartitionSpec('row', 'col')
assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col')
assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col')

if __name__ == '__main__':
absltest.main()
Expand Down

0 comments on commit 492a72a

Please sign in to comment.