Skip to content

Commit

Permalink
Skip the global jit cpp cache if in/out_layouts are not None
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 665085182
  • Loading branch information
yashk2810 authored and jax authors committed Aug 20, 2024
1 parent d148ade commit 1ab6279
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 35 deletions.
23 changes: 11 additions & 12 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def shard_args(shardings: Sequence[JSharding], layouts, args,
] = {}


@lru_cache(maxsize=2048)
def is_default_layout(curr_layout, sharding, aval):
if curr_layout is None or sharding is None:
return True
Expand Down Expand Up @@ -2548,12 +2549,6 @@ def is_user_xla_layout_equal(ul: DeviceLocalLayout | AutoLayout,
else:
return ul == xl

def _check_user_xla_layout(ul, xl, what: str):
  """Assert that a user-specified layout agrees with the layout XLA chose.

  Args:
    ul: The layout requested by the user.
    xl: The layout reported by XLA for the same value.
    what: Short label interpolated into the error message; the callers in
      this file pass "input" or "output".

  Raises:
    AssertionError: If `is_user_xla_layout_equal(ul, xl)` is false.
  """
  if not is_user_xla_layout_equal(ul, xl):
    raise AssertionError(
        f"Unexpected XLA layout override: (XLA) {xl} != {ul} "
        f"(User {what} layout)")


def _get_layouts_from_executable(
xla_executable, in_layouts, out_layouts, num_ordered_effects
Expand All @@ -2569,19 +2564,23 @@ def _get_layouts_from_executable(
out_layouts_xla = out_layouts_xla[num_ordered_effects:]

new_in_layouts = []
for x, i in safe_zip(in_layouts_xla, in_layouts):
for x, l in safe_zip(in_layouts_xla, in_layouts):
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(i, DeviceLocalLayout):
_check_user_xla_layout(i, x, "input")
if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x):
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {l} "
f"(User input layout)")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_in_layouts.append(x)

new_out_layouts = []
for x, o in safe_zip(out_layouts_xla, out_layouts):
for x, l in safe_zip(out_layouts_xla, out_layouts):
x = DeviceLocalLayout.from_pjrt_layout(x)
if isinstance(o, DeviceLocalLayout):
_check_user_xla_layout(o, x, "output")
if isinstance(l, DeviceLocalLayout) and not is_user_xla_layout_equal(l, x):
raise AssertionError(
f"Unexpected XLA layout override: (XLA) {x} != {l} "
f"(User output layout)")
# Always append the XLA layout because it has the full information
# (tiling, etc) even if the user layout does not specify tiling.
new_out_layouts.append(x)
Expand Down
30 changes: 13 additions & 17 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,14 +351,15 @@ def cache_miss(*args, **kwargs):
return cpp_pjitted_f


def _pjit_explicit_sharding(in_shardings, out_shardings, device,
backend) -> bool:
in_shardings_flat, _ = tree_flatten(in_shardings)
out_shardings_flat, _ = tree_flatten(out_shardings)
def _pjit_explicit_sharding_and_layout(
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
device, backend) -> bool:
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(i) for i in out_shardings_flat))
any(not is_unspecified(o) for o in out_shardings_flat) or
any(i is not None for i in in_layouts_flat) or
any(o is not None for o in out_layouts_flat))


def _split_layout_and_sharding(entries):
Expand Down Expand Up @@ -444,8 +445,9 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
static_argnames)

has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, device, backend)
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings_leaves, out_shardings_leaves, in_layouts_leaves,
out_layouts_leaves, device, backend)

return PjitInfo(
fun_sourceinfo=fun_sourceinfo,
Expand Down Expand Up @@ -1723,8 +1725,8 @@ def call_impl_cache_miss(*args_, **kwargs_):
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, None, None)
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
Expand Down Expand Up @@ -1753,14 +1755,8 @@ def _pjit_lower_cached(
lowering_platforms: tuple[str, ...] | None,
lowering_parameters: mlir.LoweringParameters,
pgle_profiler: profiler.PGLEProfiler | None):
if resource_env is not None:
mesh = resource_env.physical_mesh
api_name = 'pjit'
else:
# resource_env is `None` in the jit wrapper around pjit.
mesh = None
api_name = 'jit'

mesh, api_name = ((resource_env.physical_mesh, 'pjit')
if resource_env is not None else (None, 'jit'))
return pxla.lower_sharding_computation(
jaxpr, api_name, name, in_shardings, out_shardings,
in_layouts, out_layouts, tuple(donated_invars),
Expand Down
12 changes: 6 additions & 6 deletions jax/experimental/multihost_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
from jax._src.interpreters import pxla
from jax.interpreters import xla
from jax._src import pjit as pjit_lib
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec as P
from jax._src import distributed
from jax._src.util import safe_zip
Expand Down Expand Up @@ -91,17 +90,19 @@ def sync_global_devices(name: str):
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")


# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
return x

@lru_cache(maxsize=128)
def _jitted_identity_fn(sharding):
  """Return `jax.jit(_identity_fn)` with `out_shardings` set to `sharding`.

  The result is memoized with `lru_cache` (up to 128 entries), so calling
  this again with an equal `sharding` hands back the same jitted callable
  instead of building a new one. `sharding` must be hashable because it is
  the cache key.
  """
  return jax.jit(_identity_fn, out_shardings=sharding)


def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment)
out = pjit(_identity_fn, out_shardings=reps)(inp)
out = _jitted_identity_fn(reps)(inp)
else:
# All inputs here will be fully addressable.
if jax.process_count() == 1:
Expand All @@ -124,8 +125,7 @@ def _handle_array_process_allgather(inp, tiled):
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
global_aval.shape, s, bufs)
with global_mesh:
out = pjit(_identity_fn, out_shardings=None)(global_arr)
out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr)

return np.asarray(out.addressable_data(0))

Expand Down

0 comments on commit 1ab6279

Please sign in to comment.