Support shard_map with sharding in types. Right now only full manual mode is supported.

This change also adds `AxisTypes` to `Mesh`, which can be `User`, `Auto`, or `Collective`.

In follow-up changes, I'll remove the `config.sharding_in_types` flag, and we'll enter the various modes via the AxisTypes specified on the mesh.
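As a rough sketch of the new constructor argument (assumptions: `Mesh` and `AxisTypes` are imported from the internal `jax._src.mesh` module as of this commit; the single-axis layout below is purely illustrative):

# Sketch only: axis_types maps each mesh axis name to an AxisTypes member.
import numpy as np
import jax
from jax._src.mesh import Mesh, AxisTypes

devices = np.array(jax.devices())
mesh = Mesh(devices, ('x',), axis_types={'x': AxisTypes.User})
print(repr(mesh))  # the repr now also reports axis_types={'x': ...}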

PiperOrigin-RevId: 693048291
yashk2810 authored and Google-ML-Automation committed Nov 14, 2024
1 parent 8370082 commit d805875
Showing 7 changed files with 133 additions and 38 deletions.
22 changes: 21 additions & 1 deletion jax/_src/core.py
@@ -1657,7 +1657,10 @@ def str_short(self, short_dtypes=False):
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
return f'{dt_str}[{shapestr}]'
axis_types = self.sharding.mesh.axis_types
axt = (f"{{{', '.join(_get_axis_type_str(axis_types))}}}"
if axis_types is not None else '')
return f'{dt_str}[{shapestr}]{axt}'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'
@@ -1669,6 +1672,23 @@ def _len(self, ignored_tracer):
raise TypeError("len() of unsized object") from err # same as numpy error


def _get_axis_type_str(axis_types):
from jax._src.mesh import AxisTypes # type: ignore

d = defaultdict(list)
for axis, t in axis_types.items():
d[t].append(axis)
for t, axes in d.items():
a = ','.join(a for a in axes)
a = a if len(a) == 1 else f"({a})"
if t == AxisTypes.Collective:
yield f"C:{a}"
elif t == AxisTypes.User:
yield f"U:{a}"
else:
assert t == AxisTypes.Auto
yield f"A:{a}"

def _get_shape_sharding_str(shape, spec):
for s1, s2 in zip(shape, spec):
if s2 is None:
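For reference, here is a standalone sketch (not JAX API) of how the new {C:…, U:…, A:…} suffix groups axes by type; the enum members and prefixes mirror `_get_axis_type_str` above, everything else is illustrative:

from collections import defaultdict
from enum import Enum, auto

class AxisTypes(Enum):  # mirrors the enum added to jax/_src/mesh.py below
  Auto = auto()
  User = auto()
  Collective = auto()

def axis_type_suffix(axis_types):
  # Group axis names by their type, then emit C:/U:/A: segments.
  groups = defaultdict(list)
  for axis, t in axis_types.items():
    groups[t].append(axis)
  prefixes = {AxisTypes.Collective: 'C', AxisTypes.User: 'U', AxisTypes.Auto: 'A'}
  parts = []
  for t, axes in groups.items():
    joined = ','.join(axes)
    if len(axes) > 1:
      joined = f"({joined})"
    parts.append(f"{prefixes[t]}:{joined}")
  return '{' + ', '.join(parts) + '}'

print(axis_type_suffix({'x': AxisTypes.Collective, 'y': AxisTypes.Collective}))
# -> {C:(x,y)}, so an aval could render as something like float32[8,2]{C:(x,y)}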
14 changes: 6 additions & 8 deletions jax/_src/lax/lax.py
@@ -2203,14 +2203,13 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
for op, in_aval in zip(ops, in_avals):
if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
out.append(op)
elif in_aval.sharding.mesh.are_all_axis_types_collective:
out.append(op)
else:
# TODO(yashkatariya, dougalm): If `in_aval.sharding` contains
# CompilerShardingAxis, then specify `unspecified_dims` via
# `wrap_with_sharding_op`.
if config.use_shardy_partitioner.value:
sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim)
else:
sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp))
return out

@@ -2227,10 +2226,9 @@ def _nary_lower_hlo(op: Callable, ctx,

out = op(*args)
if config.sharding_in_types.value:
if config.use_shardy_partitioner.value:
out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim)
else:
out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
if aval_out.sharding.mesh.are_all_axis_types_collective:
return [out]
out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)]
else:
return [out]
50 changes: 40 additions & 10 deletions jax/_src/mesh.py
@@ -18,6 +18,7 @@
import collections
from collections.abc import Hashable, Sequence
import contextlib
import enum
import functools
import math
import threading
@@ -101,6 +102,12 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
return Mesh(global_mesh.devices[subcube_indices_tuple], global_mesh.axis_names)


class AxisTypes(enum.Enum):
Auto = enum.auto()
User = enum.auto()
Collective = enum.auto()


_mesh_object_dict = {} # type: ignore


@@ -157,9 +164,11 @@ class Mesh(contextlib.ContextDecorator):

devices: np.ndarray
axis_names: tuple[MeshAxisName, ...]
axis_types: dict[str, AxisTypes] | None

def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
axis_names: str | Sequence[MeshAxisName]):
axis_names: str | Sequence[MeshAxisName],
axis_types: dict[str, AxisTypes] | None = None):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
@@ -175,7 +184,10 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
f"devices.ndim == {devices.ndim} and "
f"len(axis_names) == {len(axis_names)}.")

key = (axis_names, devices.shape, tuple(devices.flat))
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
axis_types_tuple = (None if axis_types is None else
tuple(axis_types.items()))
key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
val = _mesh_object_dict.get(key, None)
if val is not None:
return val
@@ -184,11 +196,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
self.devices = devices.copy()
self.devices.flags.writeable = False
self.axis_names = axis_names
self.axis_types = axis_types
self._axis_types_tuple = axis_types_tuple
_mesh_object_dict[key] = self
return self

def __reduce__(self):
return (type(self), (self.devices, self.axis_names))
return (type(self), (self.devices, self.axis_names, self.axis_types))

def __eq__(self, other):
if not isinstance(other, Mesh):
@@ -199,12 +213,14 @@ def __eq__(self, other):
return True
return (self.axis_names == other.axis_names and
self.devices.shape == other.devices.shape and
self._axis_types_tuple == other._axis_types_tuple and
self._internal_device_list == other._internal_device_list)

def __hash__(self):
if not hasattr(self, '_hash'):
self._hash = hash(
(self.axis_names, self._internal_device_list, self.devices.shape))
(self.axis_names, self._internal_device_list, self.devices.shape,
self._axis_types_tuple))
return self._hash

def __setattr__(self, name, value):
@@ -301,7 +317,8 @@ def __str__(self):
def _repr(self):
if self.empty:
return "Mesh(device_ids=[], axis_names=())"
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})"
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"

def __repr__(self):
return self._repr
@@ -313,7 +330,7 @@ def local_devices(self):

@functools.cached_property
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple)
return AbstractMesh(self.shape_tuple, self.axis_types)


EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -338,25 +355,32 @@ class AbstractMesh:
details.
"""

def __init__(self, shape_tuple: tuple[tuple[str, int], ...]):
def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
axis_types: dict[str, AxisTypes] | None = None):
self.shape_tuple = shape_tuple
self.axis_types = axis_types
if self.shape_tuple:
self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
else:
self._axis_names, self._axis_sizes = (), ()
# TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
self._axis_types_tuple = (None if axis_types is None else
tuple(axis_types.items()))

def __hash__(self):
return hash(self.shape_tuple)
return hash((self.shape_tuple, self._axis_types_tuple))

def __eq__(self, other):
if not isinstance(other, AbstractMesh):
return False
if id(self) == id(other):
return True
return self.shape_tuple == other.shape_tuple
return (self.shape_tuple == other.shape_tuple and
self._axis_types_tuple == other._axis_types_tuple)

def __repr__(self):
return f"AbstractMesh({self.shape_tuple})"
atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
return f"AbstractMesh({self.shape_tuple}{atr})"

@property
def axis_names(self):
@@ -382,6 +406,12 @@ def _internal_device_list(self):
def empty(self):
return self.size == 0

@functools.cached_property
def are_all_axis_types_collective(self) -> bool:
if self.axis_types is None:
return False
return all(t == AxisTypes.Collective for t in self.axis_types.values())

@property
def devices(self):
_raise_value_error("devices")
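A minimal sketch of how the new `are_all_axis_types_collective` property behaves (assumption: this commit's internal `AbstractMesh` constructor, as shown above):

from jax._src.mesh import AbstractMesh, AxisTypes

# All axes Collective -> True; any other type, or axis_types=None -> False.
amesh = AbstractMesh((('x', 2), ('y', 2)),
                     axis_types={'x': AxisTypes.Collective, 'y': AxisTypes.Collective})
print(amesh.are_all_axis_types_collective)  # True

amesh2 = AbstractMesh((('x', 2), ('y', 2)))  # axis_types defaults to None
print(amesh2.are_all_axis_types_collective)  # False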
20 changes: 20 additions & 0 deletions jax/_src/partition_spec.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

class _UnconstrainedPartitionSingleton:

def __repr__(self):
@@ -48,3 +50,21 @@ def __repr__(self):

def __reduce__(self):
return (PartitionSpec, tuple(self))

def _normalized_spec(self, ndim: int) -> PartitionSpec:
out = [] # type: ignore
for p in self:
if p is None:
out.append(None)
elif p == self.UNCONSTRAINED:
out.append(p)
elif isinstance(p, (list, tuple)):
if len(p) == 1:
out.append(p[0])
else:
out.append(tuple(p))
else:
out.append(p)
if len(out) < ndim:
out.extend([None] * (ndim - len(out)))
return PartitionSpec(*out)
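As a quick illustration (the specific spec values below are made up), the new `PartitionSpec._normalized_spec` unwraps singleton tuples/lists, converts multi-element entries to tuples, and pads with None up to ndim:

from jax.sharding import PartitionSpec as P

spec = P(('x',), None, ['y', 'z'])
print(spec._normalized_spec(4))
# expected: PartitionSpec('x', None, ('y', 'z'), None)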
14 changes: 1 addition & 13 deletions jax/_src/sharding_impls.py
@@ -361,19 +361,7 @@ def with_memory_kind(self, kind: str) -> NamedSharding:
return NamedSharding(self.mesh, self.spec, memory_kind=kind)

def _normalized_spec(self, ndim: int) -> PartitionSpec:
out = [] # type: ignore
for p in self._parsed_pspec:
if p is None:
raise ValueError("UNCONSTRAINED is not supported yet.")
if not p:
out.append(None)
elif isinstance(p, tuple) and len(p) == 1:
out.append(p[0])
else:
out.append(p)
if len(out) < ndim:
out.extend([None] * (ndim - len(out)))
return PartitionSpec(*out)
return self.spec._normalized_spec(ndim)

def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
27 changes: 21 additions & 6 deletions jax/experimental/shard_map.py
@@ -46,7 +46,7 @@
from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
@@ -528,17 +528,32 @@ def _unshard_aval(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
raise NotImplementedError(f"Unsupported aval type: {type(aval)}")

def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
) -> core.AbstractValue:
) -> core.AbstractValue:
assert isinstance(aval, core.ShapedArray)
return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)))
new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
if config.sharding_in_types.value:
axis_names = mesh.axis_names
new_mesh = AbstractMesh(
mesh.shape_tuple,
dict(zip(axis_names, [AxisTypes.Collective] * len(axis_names))))
new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
else:
new_sharding = None
return aval.update(shape=new_shape, sharding=new_sharding)
core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array

def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
aval: core.AbstractValue,) -> core.AbstractValue:
assert isinstance(aval, core.ShapedArray)
return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape)))
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
if config.sharding_in_types.value:
spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
else:
new_sharding = None
return aval.update(shape=new_shape, sharding=new_sharding)
core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array

# Type-checking
24 changes: 24 additions & 0 deletions tests/pjit_test.py
@@ -5201,6 +5201,30 @@ def f(x):
self.assertArraysEqual(out, np_inp)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None)))

def test_shard_map_full_manual(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)
arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y')))

def g(x, y):
self.assertTrue(x.sharding.mesh.are_all_axis_types_collective)
self.assertTrue(y.sharding.mesh.are_all_axis_types_collective)
return x * y

@jax.jit
def f(x, y):
z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec),
out_specs=P('x', 'y'))(x, y)
self.assertEqual(z.sharding.spec, P('x', 'y'))
out = z * 2
self.assertEqual(out.sharding.spec, P('x', 'y'))
return out

out = f(arr, arr2)
self.assertArraysEqual(out, (np_inp * np_inp) * 2)
self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y')))


@jtu.pytest_mark_if_available('multiaccelerator')
class PJitErrorTest(jtu.JaxTestCase):
