Standardize default layout to None in internals (dispatch, lowering and compilation) and non-default layouts to concrete layouts.

This massively simplifies the number of checks we need and improves dispatch time too. It also fixes a donation bug hit in serving code that was caused by layouts and the non-standardized handling of the default layout in JAX.
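The convention is easiest to see as a small sketch. This is not a real JAX API, just the rule the internals now follow; `is_default_layout` is the private predicate in jax/_src/interpreters/pxla.py that this change extends:

```python
# Sketch only; is_default_layout is a private helper in JAX internals.
from jax._src.interpreters.pxla import is_default_layout

def normalize_layouts_for_dispatch(layouts, shardings, avals):
  # None means "the backend's default layout"; anything non-default stays
  # a concrete DeviceLocalLayout.
  return [None if is_default_layout(l, s, a) else l
          for l, s, a in zip(layouts, shardings, avals)]
```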

PiperOrigin-RevId: 668527139
yashk2810 authored and jax authors committed Aug 28, 2024
1 parent 4695705 commit ef33cf5
Showing 5 changed files with 80 additions and 92 deletions.
11 changes: 3 additions & 8 deletions jax/_src/array.py
@@ -751,8 +751,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray:
and sharding.is_fully_replicated
and first_value.is_fully_replicated
and first_value.sharding._device_assignment == tuple(devices)
and (first_value.layout.device_local_layout ==
pxla._maybe_get_default_layout(Layout(dll, sharding), None, sharding, aval))):
and first_value.layout.device_local_layout == dll):
return first_value

if dtypes.issubdtype(aval.dtype, dtypes.extended):
@@ -1105,11 +1104,6 @@ def _sharding_indices_and_eq(src_sharding, shape, dst_sharding):
dst_indices = dst_sharding.addressable_devices_indices_map(shape).values()
return dst_indices, tuple(src_indices) == tuple(dst_indices)

def _layout_eq(x, dst_layout, sharding):
if pxla.is_default_layout(dst_layout, sharding, x.aval):
return True
return x.layout.device_local_layout == dst_layout


def _array_shard_arg(xs, shardings, layouts):
results = []
@@ -1118,7 +1112,8 @@ def _array_shard_arg(xs, shardings, layouts):
for i, (x, sharding, layout) in enumerate(safe_zip(xs, shardings, layouts)):
x._check_if_deleted()
indices, same_indices = _sharding_indices_and_eq(x.sharding, x.shape, sharding)
same_layout = _layout_eq(x, layout, sharding)
same_layout = (True if layout is None else
x.layout.device_local_layout == layout)

if not x.is_fully_addressable:
if same_indices and same_layout:
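With defaults standardized to None before dispatch, the per-argument layout check in `_array_shard_arg` no longer has to recompute what the default layout would be. The new check is equivalent to this sketch:

```python
# None means "default layout", which every argument trivially satisfies;
# only a concrete (non-default) layout needs an equality check.
same_layout = (layout is None
               or x.layout.device_local_layout == layout)
```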
23 changes: 11 additions & 12 deletions jax/_src/interpreters/mlir.py
@@ -1053,6 +1053,7 @@ def lower_jaxpr_to_module(
result_memory_kinds = (map(_get_mem_kind, result_shardings)
if result_shardings is not None else None)

# TODO(yashkatariya): Simplify the donation logic.
xla_donated_args = None
platforms_with_donation = [p for p in platforms
if p in _platforms_with_donation]
@@ -1071,9 +1072,6 @@ def lower_jaxpr_to_module(
input_output_aliases, donated_args, xla_donated_args = _set_up_aliases(
input_output_aliases, in_avals, out_avals, donated_args,
arg_memory_kinds, result_memory_kinds, in_layouts, out_layouts)
unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')
if any(donated_args):
unused_donations = [str(a) for a, d in zip(in_avals, donated_args) if d]
msg = "See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation."
@@ -1082,10 +1080,13 @@ def lower_jaxpr_to_module(
if unused_donations:
warnings.warn("Some donated buffers were not usable:"
f" {', '.join(unused_donations)}.\n{msg}")

# Delete donated_args by default here, since it's not needed beyond this point
del donated_args

unlowerable_effects = lowerable_effects.filter_not_in(jaxpr.effects)
if unlowerable_effects:
raise ValueError(f'Cannot lower jaxpr with effects: {jaxpr.effects}')

# HLO channels need to start at 1
channel_iter = itertools.count(1)
# Create a keepalives list that will be mutated during the lowering.
@@ -1167,8 +1168,7 @@ def emit_diagnostic_info(d):


def _set_up_aliases(input_output_aliases, avals_in, avals_out,
donated_args,
arg_memory_kinds, result_memory_kinds,
donated_args, arg_memory_kinds, result_memory_kinds,
in_layouts, out_layouts):
if input_output_aliases is None:
input_output_aliases = [None] * len(avals_in)
@@ -1207,15 +1207,14 @@ def _set_up_aliases(input_output_aliases, avals_in, avals_out,
if donations.get(key, ()):
input_id = donations[key].popleft()
out_donated_args[input_id] = False
# We can alias if XLA performs layout assignment because XLA will
# respect the aliases when assigning layouts. Its only for two
# mismatched explicitly assigned layouts that XLA will certainly fail.
if (in_layouts is None or
out_layouts is None or
in_layouts[input_id] == out_layouts[i] or
# We can alias if XLA performs layout assignment because XLA will
# respect the aliases when assigning layouts. Its only for two
# mismatched explicitly assigned layouts that XLA will certainly
# fail.
isinstance(in_layouts[input_id], (AutoLayout, type(None))) or
isinstance(out_layouts[i], (AutoLayout, type(None)))):
isinstance(in_layouts[input_id], AutoLayout) or
isinstance(out_layouts[i], AutoLayout)):
input_output_aliases[input_id] = i
else:
# Fallback to xla donation if layouts don't match.
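The condition shuffle in `_set_up_aliases` encodes a single rule, sketched below in simplified form (assuming `AutoLayout` is the AUTO-layout marker this file imports from jax/_src/layout.py):

```python
# Sketch of the aliasing decision. Aliasing is safe when the layouts match
# (None == None covers default-to-default) or when XLA is free to assign a
# layout on at least one side (AUTO); two different pinned layouts would
# make XLA fail, so _set_up_aliases falls back to plain XLA donation.
def can_alias(in_layout, out_layout):
  return (in_layout == out_layout
          or isinstance(in_layout, AutoLayout)
          or isinstance(out_layout, AutoLayout))
```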
93 changes: 32 additions & 61 deletions jax/_src/interpreters/pxla.py
@@ -150,7 +150,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args,

@lru_cache(maxsize=2048)
def is_default_layout(curr_layout, sharding, aval):
if curr_layout is None or sharding is None:
if curr_layout is None or sharding is None or is_unspecified(sharding):
return True
if (aval is core.abstract_token or aval.dtype == dtypes.float0 or
dtypes.issubdtype(aval.dtype, dtypes.extended)):
@@ -191,7 +191,7 @@ def _shard_np_array(xs, shardings, layouts):
if x.dtype == dtypes.float0:
x = np.zeros(x.shape, dtype=np.dtype(bool))
aval = api_util.shaped_abstractify(x)
if not is_default_layout(layout, sharding, aval):
if layout is not None:
results.append(api.device_put(x, Layout(layout, sharding)))
else:
if sharding.is_fully_replicated:
@@ -1884,35 +1884,6 @@ def _raise_warnings_or_errors_for_jit_of_pmap(
"extra data movement anyway, so maybe you don't want it after all).")


@lru_cache(maxsize=2048)
def _maybe_get_default_layout(arg_layout, jit_in_layout, sharding, aval
) -> DeviceLocalLayout | None:
if is_unspecified_or_auto(sharding):
return None
# TODO(yashkatariya): Figure out how layouts work with extended dtypes.
if aval is core.abstract_token or dtypes.issubdtype(aval.dtype, dtypes.extended):
return None
if not core.is_constant_shape(aval.shape):
return None
shard_shape = sharding.shard_shape(aval.shape)
d = sharding._device_assignment[0]
# If a backend doesn't implement `get_default_layout` return `None` to avoid
# cache misses. This can happen when you have `jit(f, in_shardings=s)`. On
# first call you pass it a sharded array with layout and on second call you
# pass a numpy array. The layouts should be the same to get cache hits.
try:
al = DeviceLocalLayout.from_pjrt_layout(
d.client.get_default_layout(aval.dtype, shard_shape, d))
except:
return None
# argument does not have `.layout` property. ShapedArray, numpy array, etc
# are some examples.
if arg_layout is None:
return al if jit_in_layout is None else arg_layout # arg_layout is None
# If arg has a `.layout` property, then return device_local_layout as is.
return arg_layout.device_local_layout


@weakref_lru_cache
def _cached_lowering_to_hlo(closed_jaxpr, api_name, fun_name, backend,
semantic_in_shardings, semantic_out_shardings,
@@ -2775,13 +2746,14 @@ class UnloadedMeshExecutable:
kept_var_idx: set[int]
mut: MutationData | None
auto_spmd_lowering: bool
in_layouts: Sequence[DeviceLocalLayout | None]
out_layouts: Sequence[DeviceLocalLayout | None]
xla_in_layouts: Sequence[DeviceLocalLayout | None]
dispatch_in_layouts: Sequence[DeviceLocalLayout | None]
xla_out_layouts: Sequence[DeviceLocalLayout | None]
all_args_info: AllArgsInfo | None
pgle_profiler: profiler.PGLEProfiler | None

def build_unsafe_call(self):
handle_args = InputsHandler(self.input_shardings, self.in_layouts)
handle_args = InputsHandler(self.input_shardings, self.dispatch_in_layouts)
handle_outs = global_avals_to_results_handler(
self.output_avals, self.output_shardings, self.committed)

@@ -2797,8 +2769,8 @@ def load(self) -> MeshExecutable:
self.input_avals, self.output_avals,
self.input_shardings, self.output_shardings,
self.auto_spmd_lowering, self.kept_var_idx,
self.in_layouts, self.out_layouts,
self.all_args_info, self)
self.xla_in_layouts, self.dispatch_in_layouts,
self.xla_out_layouts, self.all_args_info, self)

@staticmethod
def from_hlo(name: str,
@@ -2881,8 +2853,18 @@ def from_hlo(name: str,
in_shardings, out_shardings, committed, da = _get_metadata_jit_pmap(
xla_executable.local_devices(), len(in_shardings), len(out_shardings))

in_layouts, out_layouts = _get_layouts_from_executable(
# xla_in_layouts are all either None or DeviceLocalLayout. Even default
# layout are concrete layouts and they are used in `compiled.input_layouts`
# to return concrete layouts to users.
# `dispatch_in_layouts` replaces default layouts with `None` to simplify
# dispatch logic downstream.
xla_in_layouts, xla_out_layouts = _get_layouts_from_executable(
xla_executable, in_layouts, out_layouts, len(ordered_effects))
del in_layouts, out_layouts
dispatch_in_layouts = [
None if is_default_layout(l, s, a) else l
for l, s, a, in safe_zip(xla_in_layouts, in_shardings, global_in_avals)
]

out_shardings = maybe_recover_user_shardings(
in_shardings, out_shardings, global_in_avals, global_out_avals,
@@ -2907,8 +2889,9 @@ def from_hlo(name: str,
kept_var_idx=kept_var_idx,
mut=mut,
auto_spmd_lowering=auto_spmd_lowering,
in_layouts=in_layouts,
out_layouts=out_layouts,
xla_in_layouts=xla_in_layouts,
dispatch_in_layouts=dispatch_in_layouts,
xla_out_layouts=xla_out_layouts,
all_args_info=all_args_info,
pgle_profiler=pgle_profiler).load()

@@ -2964,13 +2947,13 @@ class MeshExecutable(stages.XlaExecutable):
__slots__ = [
"xla_executable", "_unsafe_call", "build_unsafe_call", "in_avals",
"out_avals", "_in_shardings", "_out_shardings", "_auto_spmd_lowering",
"_kept_var_idx", "_in_layouts", "_out_layouts", "_all_args_info",
"_unloaded_executable",
"_kept_var_idx", "_xla_in_layouts", "_dispatch_in_layouts",
"_xla_out_layouts", "_all_args_info", "_unloaded_executable",
]

def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx,
in_layouts, out_layouts,
xla_in_layouts, dispatch_in_layouts, xla_out_layouts,
all_args_info: AllArgsInfo | None = None,
unloaded_executable=None):
self.xla_executable = xla_executable
Expand All @@ -2984,8 +2967,9 @@ def __init__(self, xla_executable, build_unsafe_call, in_avals, out_avals,
self._out_shardings = out_shardings
self._auto_spmd_lowering = auto_spmd_lowering
self._kept_var_idx = kept_var_idx
self._in_layouts = in_layouts
self._out_layouts = out_layouts
self._xla_in_layouts = xla_in_layouts
self._dispatch_in_layouts = dispatch_in_layouts
self._xla_out_layouts = xla_out_layouts
self._all_args_info = all_args_info
self._unloaded_executable = unloaded_executable

@@ -3013,9 +2997,8 @@ def call(self, *args):

all_arg_avals = map(xla.abstractify, kept_args)
check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
# Check the GDA sharding and the input sharding.
check_array_xla_sharding_layout_match(
args_after_dce, self._in_shardings, self._in_layouts, debug_info,
args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,
self._kept_var_idx)
return self.unsafe_call(*args) # pylint: disable=not-callable

@@ -3027,11 +3010,11 @@ def output_shardings(self) -> Sequence[JSharding]:

def input_layouts(self):
return [Layout(l, s)
for l, s in safe_zip(self._in_layouts, self._in_shardings)]
for l, s in safe_zip(self._xla_in_layouts, self._in_shardings)]

def output_layouts(self):
return [Layout(l, s)
for l, s in safe_zip(self._out_layouts, self._out_shardings)]
for l, s in safe_zip(self._xla_out_layouts, self._out_shardings)]

def create_cpp_call(self, no_kwargs, in_tree, out_tree):
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
@@ -3057,12 +3040,10 @@ def aot_cache_miss(*args, **kwargs):
else s
for s, a in zip(self._in_shardings, self.in_avals)
]
in_dlls = get_layouts_for_fasthpath_data(
self._in_layouts, in_shardings, self.in_avals)
fastpath_data = MeshExecutableFastpathData(
self.xla_executable, out_tree_dispatch, in_shardings,
self._out_shardings, out_avals, out_committed, kept_var_bitvec,
in_dlls)
self._dispatch_in_layouts)
else:
fastpath_data = None
return outs, fastpath_data, False # Do not remove cache entry
@@ -3084,16 +3065,6 @@ def cc_shard_arg(x, sharding, layout): # type: ignore
return shard_args([sharding], [layout], [x])[0]


def get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals):
in_dlls = []
for l, s, a in zip(in_layouts, in_shardings, in_avals):
if is_default_layout(l, s, a):
in_dlls.append(None)
else:
in_dlls.append(l)
return in_dlls


def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.JaxprDebugInfo | None = None):
if len(ref_avals) != len(arg_avals):
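The net effect in pxla.py: default-layout detection moves from every fast-path dispatch (the deleted `get_layouts_for_fasthpath_data`) to a single pass at executable build time. A sketch reusing the names from the diff above:

```python
# Before: recomputed which layouts were defaults on every cached dispatch.
#   in_dlls = get_layouts_for_fasthpath_data(in_layouts, in_shardings, in_avals)
# After: computed once in UnloadedMeshExecutable.from_hlo and stored on the
# executable, so the fast path just forwards _dispatch_in_layouts.
dispatch_in_layouts = [
    None if is_default_layout(l, s, a) else l
    for l, s, a in safe_zip(xla_in_layouts, in_shardings, global_in_avals)
]
```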
21 changes: 13 additions & 8 deletions jax/_src/pjit.py
@@ -279,12 +279,10 @@ def _get_fastpath_data(
else s
for s, a in zip(executable._in_shardings, executable.in_avals)
]
in_dlls = pxla.get_layouts_for_fasthpath_data(
executable._in_layouts, in_shardings, executable.in_avals)
fastpath_data = pxla.MeshExecutableFastpathData(
executable.xla_executable, out_tree, in_shardings,
executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
in_dlls)
executable._dispatch_in_layouts)
else:
fastpath_data = None
return fastpath_data
@@ -1479,10 +1477,17 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
resolved_in_layouts = []
for arg, jit_in_l, rs, aval in safe_zip(
args, jit_in_layouts, resolved_in_shardings, in_avals):
arg_layout, committed = (
pxla._maybe_get_default_layout(getattr(arg, 'layout', None), jit_in_l,
rs, aval),
getattr(arg, '_committed', True))
committed = getattr(arg, '_committed', True)
# `arg_layout` is only used for checking purposes in the `else` branch
# below. We cannot replace default layout with None to raise nicer errors.
# `dispatch_arg_layout` replaces default layouts with `None` to simplify
# dispatch and lowering logic downstream.
if hasattr(arg, 'layout'):
arg_layout = arg.layout.device_local_layout
dispatch_arg_layout = (None if pxla.is_default_layout(arg_layout, rs, aval)
else arg_layout)
else:
arg_layout, dispatch_arg_layout = None, None
# Sharding can be unspecified when array is committed if it's a PmapSharding.
is_pmap_sharding = (is_unspecified(rs) or
isinstance(getattr(arg, 'sharding', None), PmapSharding))
@@ -1491,7 +1496,7 @@ def _resolve_in_layouts(args, jit_in_layouts, resolved_in_shardings, in_avals):
if is_pmap_sharding:
resolved_in_layouts.append(None)
else:
resolved_in_layouts.append(arg_layout)
resolved_in_layouts.append(dispatch_arg_layout)
else:
resolved_in_layouts.append(None)
else:
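None of this changes what users see: `compiled.input_layouts` (referenced in the pxla.py comment above) keeps returning concrete `Layout` objects even for backend-default layouts; only the internal dispatch path sees None. A hypothetical usage sketch — the exact spelling of the accessor (method vs. property) depends on your JAX version:

```python
import jax
import jax.numpy as jnp

compiled = jax.jit(lambda x: x * 2).lower(
    jax.ShapeDtypeStruct((8, 128), jnp.float32)).compile()

# Concrete Layout objects, even when they are the backend default.
print(compiled.input_layouts())
```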
24 changes: 21 additions & 3 deletions tests/layout_test.py
@@ -540,7 +540,7 @@ def test_layout_donation(self):
def f(x):
return x

out = f(arr)
f(arr)
self.assertTrue(arr.is_deleted())

def test_layout_donation_auto(self):
@@ -555,7 +555,7 @@ def test_layout_donation_auto(self):
def f(x):
return x * x

out = f(arr)
f(arr)
self.assertTrue(arr.is_deleted())

def test_layout_donation_matching_in_and_out(self):
@@ -572,9 +572,27 @@ def test_layout_donation_matching_in_and_out(self):
def f(x):
return x * x

out = f(arr)
f(arr)
self.assertTrue(arr.is_deleted())

def test_layout_donation_mismatching_in_and_out_fails(self):
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
shape = (16*2, 32016*2)
np_inp = np.arange(math.prod(shape), dtype=jnp.bfloat16).reshape(shape)

custom_dll1 = DLL(major_to_minor=(1, 0), _tiling=((8,128), (2,1)))
l1 = Layout(custom_dll1, s)
arr = jax.device_put(np_inp, s)

@partial(jax.jit, out_shardings=l1, donate_argnums=0)
def f(x):
return x * x

sds = jax.ShapeDtypeStruct(np_inp.shape, np_inp.dtype, sharding=s)
f.lower(sds).compile()(arr)
self.assertFalse(arr.is_deleted())


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
