diff --git a/jax/_src/core.py b/jax/_src/core.py
index 96aecfde3a74..c1253a4d6431 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1657,7 +1657,10 @@ def str_short(self, short_dtypes=False):
     dt_str = dt_str.replace('void', 'float0')
     if hasattr(self, 'sharding') and self.sharding is not None:
       shapestr = ','.join(_get_shape_sharding_str(self.shape, self.sharding.spec))
-      return f'{dt_str}[{shapestr}]'
+      axis_types = self.sharding.mesh.axis_types
+      axt = (f"{{{', '.join(_get_axis_type_str(axis_types))}}}"
+             if axis_types is not None else '')
+      return f'{dt_str}[{shapestr}]{axt}'
     else:
       shapestr = ','.join(map(str, self.shape))
       return f'{dt_str}[{shapestr}]'
@@ -1669,6 +1672,23 @@ def _len(self, ignored_tracer):
       raise TypeError("len() of unsized object") from err  # same as numpy error
 
 
+def _get_axis_type_str(axis_types):
+  from jax._src.mesh import AxisTypes  # type: ignore
+
+  d = defaultdict(list)
+  for axis, t in axis_types.items():
+    d[t].append(axis)
+  for t, axes in d.items():
+    a = ','.join(a for a in axes)
+    a = a if len(a) == 1 else f"({a})"
+    if t == AxisTypes.Collective:
+      yield f"C:{a}"
+    elif t == AxisTypes.User:
+      yield f"U:{a}"
+    else:
+      assert t == AxisTypes.Auto
+      yield f"A:{a}"
+
 def _get_shape_sharding_str(shape, spec):
   for s1, s2 in zip(shape, spec):
     if s2 is None:
diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py
index 7e15f46c3ef1..88c61de03fa4 100644
--- a/jax/_src/lax/lax.py
+++ b/jax/_src/lax/lax.py
@@ -2203,14 +2203,13 @@ def multi_sharding_in_dim(ctx, ops, in_avals, out_aval):
   for op, in_aval in zip(ops, in_avals):
     if in_aval.sharding == out_aval.sharding or in_aval.sharding is None:
       out.append(op)
+    elif in_aval.sharding.mesh.are_all_axis_types_collective:
+      out.append(op)
     else:
       # TODO(yashkatariya, dougalm): If `in_aval.sharding` contains
       # CompilerShardingAxis, then specify `unspecified_dims` via
       # `wrap_with_sharding_op`.
-      if config.use_shardy_partitioner.value:
-        sp = in_aval.sharding._to_sdy_sharding(in_aval.ndim)
-      else:
-        sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
+      sp = in_aval.sharding._to_xla_hlo_sharding(in_aval.ndim).to_proto()
       out.append(mlir.wrap_with_sharding_op(ctx, op, out_aval, sp))
   return out
 
@@ -2227,10 +2226,9 @@ def _nary_lower_hlo(op: Callable, ctx,
 
   out = op(*args)
   if config.sharding_in_types.value:
-    if config.use_shardy_partitioner.value:
-      out_sp = aval_out.sharding._to_sdy_sharding(aval_out.ndim)
-    else:
-      out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
+    if aval_out.sharding.mesh.are_all_axis_types_collective:
+      return [out]
+    out_sp = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
     return [mlir.wrap_with_sharding_op(ctx, out, aval_out, out_sp)]
   else:
     return [out]
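
Note on the new _get_axis_type_str helper above: it groups mesh axes by their AxisTypes value and prefixes each group with C (Collective), U (User) or A (Auto), parenthesizing multi-axis groups. A minimal sketch of the expected behaviour, assuming this patch is applied (the helper is private to jax._src.core and the axis names 'x', 'y', 'z' are purely illustrative):

    from jax._src.core import _get_axis_type_str
    from jax._src.mesh import AxisTypes

    axis_types = {'x': AxisTypes.Collective, 'y': AxisTypes.Collective,
                  'z': AxisTypes.User}
    # Expected: ['C:(x,y)', 'U:z'], so str_short would render an aval as
    # something like f32[...]{C:(x,y), U:z}.
    print(list(_get_axis_type_str(axis_types)))
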
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index 43791f2e5f72..f734ef62d86c 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -18,6 +18,7 @@
 import collections
 from collections.abc import Hashable, Sequence
 import contextlib
+import enum
 import functools
 import math
 import threading
@@ -101,6 +102,12 @@ def _get_local_mesh(global_mesh: Mesh, process_index: int) -> Mesh:
   return Mesh(global_mesh.devices[subcube_indices_tuple],
               global_mesh.axis_names)
 
 
+class AxisTypes(enum.Enum):
+  Auto = enum.auto()
+  User = enum.auto()
+  Collective = enum.auto()
+
+
 _mesh_object_dict = {}  # type: ignore
 
@@ -157,9 +164,11 @@ class Mesh(contextlib.ContextDecorator):
 
   devices: np.ndarray
   axis_names: tuple[MeshAxisName, ...]
+  axis_types: dict[str, AxisTypes] | None
 
   def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
-              axis_names: str | Sequence[MeshAxisName]):
+              axis_names: str | Sequence[MeshAxisName],
+              axis_types: dict[str, AxisTypes] | None = None):
     if not isinstance(devices, np.ndarray):
       devices = np.array(devices)
     if isinstance(axis_names, str):
@@ -175,7 +184,10 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
           f"devices.ndim == {devices.ndim} and "
           f"len(axis_names) == {len(axis_names)}.")
 
-    key = (axis_names, devices.shape, tuple(devices.flat))
+    # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
+    axis_types_tuple = (None if axis_types is None else
+                        tuple(axis_types.items()))
+    key = (axis_names, devices.shape, tuple(devices.flat), axis_types_tuple)
     val = _mesh_object_dict.get(key, None)
     if val is not None:
       return val
@@ -184,11 +196,13 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
     self.devices = devices.copy()
     self.devices.flags.writeable = False
     self.axis_names = axis_names
+    self.axis_types = axis_types
+    self._axis_types_tuple = axis_types_tuple
     _mesh_object_dict[key] = self
     return self
 
   def __reduce__(self):
-    return (type(self), (self.devices, self.axis_names))
+    return (type(self), (self.devices, self.axis_names, self.axis_types))
 
   def __eq__(self, other):
     if not isinstance(other, Mesh):
@@ -199,12 +213,14 @@ def __eq__(self, other):
       return True
     return (self.axis_names == other.axis_names and
             self.devices.shape == other.devices.shape and
+            self._axis_types_tuple == other._axis_types_tuple and
             self._internal_device_list == other._internal_device_list)
 
   def __hash__(self):
     if not hasattr(self, '_hash'):
       self._hash = hash(
-          (self.axis_names, self._internal_device_list, self.devices.shape))
+          (self.axis_names, self._internal_device_list, self.devices.shape,
+           self._axis_types_tuple))
     return self._hash
 
   def __setattr__(self, name, value):
@@ -301,7 +317,8 @@ def __str__(self):
 
   def _repr(self):
     if self.empty: return "Mesh(device_ids=[], axis_names=())"
-    return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r})"
+    atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
+    return f"Mesh(device_ids={self.device_ids!r}, axis_names={self.axis_names!r}{atr})"
 
   def __repr__(self):
     return self._repr
@@ -313,7 +330,7 @@ def local_devices(self):
 
   @functools.cached_property
   def abstract_mesh(self):
-    return AbstractMesh(self.shape_tuple)
+    return AbstractMesh(self.shape_tuple, self.axis_types)
 
 
 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -338,25 +355,32 @@ class AbstractMesh:
   details.
   """
 
-  def __init__(self, shape_tuple: tuple[tuple[str, int], ...]):
+  def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
+               axis_types: dict[str, AxisTypes] | None = None):
     self.shape_tuple = shape_tuple
+    self.axis_types = axis_types
     if self.shape_tuple:
       self._axis_names, self._axis_sizes = list(zip(*self.shape_tuple))
     else:
       self._axis_names, self._axis_sizes = (), ()
+    # TODO(yashkatariya): If axis_types is None, set all axes to AUTO.
+    self._axis_types_tuple = (None if axis_types is None else
+                              tuple(axis_types.items()))
 
   def __hash__(self):
-    return hash(self.shape_tuple)
+    return hash((self.shape_tuple, self._axis_types_tuple))
 
   def __eq__(self, other):
     if not isinstance(other, AbstractMesh):
       return False
     if id(self) == id(other):
       return True
-    return self.shape_tuple == other.shape_tuple
+    return (self.shape_tuple == other.shape_tuple and
+            self._axis_types_tuple == other._axis_types_tuple)
 
   def __repr__(self):
-    return f"AbstractMesh({self.shape_tuple})"
+    atr = '' if self.axis_types is None else f", axis_types={self.axis_types}"
+    return f"AbstractMesh({self.shape_tuple}{atr})"
 
   @property
   def axis_names(self):
@@ -382,6 +406,12 @@ def _internal_device_list(self):
   def empty(self):
     return self.size == 0
 
+  @functools.cached_property
+  def are_all_axis_types_collective(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return all(t == AxisTypes.Collective for t in self.axis_types.values())
+
   @property
   def devices(self):
     _raise_value_error("devices")
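
For context, a quick usage sketch of the new axis_types plumbing on Mesh and AbstractMesh, assuming this patch is applied and at least two devices are available (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=2). AxisTypes is not re-exported publicly here, so the import goes through jax._src.mesh; the axis names are illustrative:

    import numpy as np
    import jax
    from jax._src.mesh import Mesh, AxisTypes

    devices = np.asarray(jax.devices()[:2]).reshape(2, 1)
    mesh = Mesh(devices, ('x', 'y'),
                axis_types={'x': AxisTypes.Collective, 'y': AxisTypes.Collective})
    # axis_types now participates in the interning key, __eq__, __hash__,
    # __reduce__ and repr, and is forwarded to the AbstractMesh built by
    # mesh.abstract_mesh.
    assert mesh.abstract_mesh.are_all_axis_types_collective
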
diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py
index 18e7d18d931d..f9bc2b60cee9 100644
--- a/jax/_src/partition_spec.py
+++ b/jax/_src/partition_spec.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 class _UnconstrainedPartitionSingleton:
 
   def __repr__(self):
@@ -48,3 +50,21 @@ def __repr__(self):
 
   def __reduce__(self):
     return (PartitionSpec, tuple(self))
+
+  def _normalized_spec(self, ndim: int) -> PartitionSpec:
+    out = []  # type: ignore
+    for p in self:
+      if p is None:
+        out.append(None)
+      elif p == self.UNCONSTRAINED:
+        out.append(p)
+      elif isinstance(p, (list, tuple)):
+        if len(p) == 1:
+          out.append(p[0])
+        else:
+          out.append(tuple(p))
+      else:
+        out.append(p)
+    if len(out) < ndim:
+      out.extend([None] * (ndim - len(out)))
+    return PartitionSpec(*out)
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index fa65bbe9328d..9b847f15d86a 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -361,19 +361,7 @@ def with_memory_kind(self, kind: str) -> NamedSharding:
     return NamedSharding(self.mesh, self.spec, memory_kind=kind)
 
   def _normalized_spec(self, ndim: int) -> PartitionSpec:
-    out = []  # type: ignore
-    for p in self._parsed_pspec:
-      if p is None:
-        raise ValueError("UNCONSTRAINED is not supported yet.")
-      if not p:
-        out.append(None)
-      elif isinstance(p, tuple) and len(p) == 1:
-        out.append(p[0])
-      else:
-        out.append(p)
-    if len(out) < ndim:
-      out.extend([None] * (ndim - len(out)))
-    return PartitionSpec(*out)
+    return self.spec._normalized_spec(ndim)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return named_sharding_to_xla_hlo_sharding(self, num_dimensions)
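
A small sketch of what the relocated _normalized_spec does, assuming this patch is applied: singleton tuples are unwrapped, the spec is padded with None up to ndim, and UNCONSTRAINED entries are now passed through instead of raising as the old NamedSharding version did.

    from jax.sharding import PartitionSpec as P

    assert P(('x',), None)._normalized_spec(3) == P('x', None, None)
    assert P(P.UNCONSTRAINED, 'y')._normalized_spec(2) == P(P.UNCONSTRAINED, 'y')
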
{type(aval)}") def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue - ) -> core.AbstractValue: + ) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz // prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + if config.sharding_in_types.value: + axis_names = mesh.axis_names + new_mesh = AbstractMesh( + mesh.shape_tuple, + dict(zip(axis_names, [AxisTypes.Collective] * len(axis_names)))) + new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim)) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.shard_aval_handlers[core.ShapedArray] = _shard_shaped_array def _unshard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue,) -> core.AbstractValue: assert isinstance(aval, core.ShapedArray) - return aval.update(tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) - for i, sz in enumerate(aval.shape))) + new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ())) + for i, sz in enumerate(aval.shape)) + if config.sharding_in_types.value: + spec = _names_to_pspec(names)._normalized_spec(aval.ndim) + new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec) + else: + new_sharding = None + return aval.update(shape=new_shape, sharding=new_sharding) core.unshard_aval_handlers[core.ShapedArray] = _unshard_shaped_array # Type-checking diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8fe46c3b83e5..26780c9cdcc8 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -5201,6 +5201,30 @@ def f(x): self.assertArraysEqual(out, np_inp) self.assertEqual(out.sharding, NamedSharding(mesh, P('x', None))) + def test_shard_map_full_manual(self): + mesh = jtu.create_mesh((2, 2), ('x', 'y')) + np_inp = np.arange(16).reshape(8, 2) + arr = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + arr2 = jax.device_put(np_inp, NamedSharding(mesh, P('x', 'y'))) + + def g(x, y): + self.assertTrue(x.sharding.mesh.are_all_axis_types_collective) + self.assertTrue(y.sharding.mesh.are_all_axis_types_collective) + return x * y + + @jax.jit + def f(x, y): + z = shard_map(g, mesh=mesh, in_specs=(x.sharding.spec, y.sharding.spec), + out_specs=P('x', 'y'))(x, y) + self.assertEqual(z.sharding.spec, P('x', 'y')) + out = z * 2 + self.assertEqual(out.sharding.spec, P('x', 'y')) + return out + + out = f(arr, arr2) + self.assertArraysEqual(out, (np_inp * np_inp) * 2) + self.assertEqual(out.sharding, NamedSharding(mesh, P('x', 'y'))) + @jtu.pytest_mark_if_available('multiaccelerator') class PJitErrorTest(jtu.JaxTestCase):