diff --git a/jax/_src/api.py b/jax/_src/api.py index 8ca3803aec35..b548cc43fb3b 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2726,7 +2726,8 @@ def clear_backends(): pjit._infer_params_cached.cache_clear() pjit._pjit_lower_cached.cache_clear() pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() xc._xla.PjitFunctionCache.clear_all() @atexit.register @@ -2755,7 +2756,8 @@ def clear_caches(): util.clear_all_weakref_lru_caches() # Clear all C++ compiled executable caches for pjit - pjit._cpp_pjit_cache.clear() + pjit._cpp_pjit_cache_fun_only.clear() + pjit._cpp_pjit_cache_explicit_attributes.clear() pjit._infer_params_cached.cache_clear() xc._xla.PjitFunctionCache.clear_all() diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index b7d68f73c2a4..944e20fa7faa 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -22,6 +22,7 @@ from collections.abc import Callable, Sequence, Iterable, Iterator import dataclasses from functools import partial, lru_cache, cached_property +import functools import itertools as it import logging import math @@ -61,6 +62,7 @@ from jax._src.interpreters import xla from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import hlo from jax._src.partition_spec import PartitionSpec @@ -88,6 +90,7 @@ class WeakRefList(list): logger = logging.getLogger(__name__) Index = Union[int, slice, tuple[Union[int, slice], ...]] +PyTreeDef = tree_util.PyTreeDef NoSharding = sharding_specs.NoSharding Chunked = sharding_specs.Chunked @@ -2904,6 +2907,34 @@ class MeshExecutableFastpathData(NamedTuple): in_device_local_layouts: Sequence[DeviceLocalLayout | None] +@dataclasses.dataclass(frozen=True, kw_only=True) +class JitGlobalCppCacheKeys: + donate_argnums: tuple[int, ...] | None = None + donate_argnames: tuple[str, ...] | None = None + device: xc.Device | None = None + backend: str | None = None + in_shardings_treedef: PyTreeDef | None = None + in_shardings_leaves: tuple[Any, ...] | None = None + out_shardings_treedef: PyTreeDef | None = None + out_shardings_leaves: tuple[Any, ...] | None = None + in_layouts_treedef: PyTreeDef | None = None + in_layouts_leaves: tuple[Any, ...] | None = None + out_layouts_treedef: PyTreeDef | None = None + out_layouts_leaves: tuple[Any, ...] | None = None + use_resource_env: bool = False + + @functools.cached_property + def contains_explicit_attributes(self): + return (self.donate_argnums is not None or + self.donate_argnames is not None or + self.device is not None or + self.backend is not None or + any(not is_unspecified(i) for i in self.in_shardings_leaves) or + any(not is_unspecified(o) for o in self.out_shardings_leaves) or + any(i is not None for i in self.in_layouts_leaves) or + any(o is not None for o in self.out_layouts_leaves)) + + def reflatten_outputs_for_dispatch(out_tree, out_flat): # We arrive at dispatch having flattened according to the default # pytree registry, but we want to re-flatten according to our @@ -3017,9 +3048,14 @@ def aot_cache_miss(*args, **kwargs): fastpath_data = None return outs, fastpath_data, False # Do not remove cache entry - return xc._xla.pjit( - self.unsafe_call.name, None, aot_cache_miss, [], [], [], - tree_util.dispatch_registry, cc_shard_arg) + if xla_extension_version >= 286: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], + JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg) + else: + return xc._xla.pjit( + self.unsafe_call.name, None, aot_cache_miss, [], [], [], + tree_util.dispatch_registry, cc_shard_arg) def cc_shard_arg(x, sharding, layout): return shard_args([sharding], [layout], [x])[0] diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fb76f7931c01..42a7c966b4d6 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -62,6 +62,7 @@ from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import jax_jit from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src import sharding from jax._src.mesh import AbstractMesh from jax._src.sharding_impls import ( @@ -164,7 +165,6 @@ class PjitInfo(NamedTuple): keep_unused: bool inline: bool abstracted_axes: Any | None - has_explicit_sharding: bool use_resource_env: bool # False for jit, True for pjit # Hash and compare PjitInfo by identity when used as a cache key. @@ -311,14 +311,39 @@ def _cpp_pjit_evict_fn(self): # The entries are doubled here from the default 4096 because _pjit_call_impl # also has a cpp dispatch path and that would double the number of entries in # the global shared cache. -_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is only used for jit's with only fun. For example: jax.jit(f) +_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192) +# This cache is used for jit where extra arguments are defined other than the +# fun. For example: jax.jit(f, donate_argnums=...) OR +# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the +# capacity might get full very fast because of all the jitted function in JAX +# which might evict train_step for example. +_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192) -def _get_cpp_global_cache(pjit_has_explicit_sharding): - if pjit_has_explicit_sharding: - return xc._xla.PjitFunctionCache() - else: - return _cpp_pjit_cache + +if xla_extension_version < 286: + def _get_cpp_global_cache(pjit_has_explicit_sharding): + if pjit_has_explicit_sharding: + return xc._xla.PjitFunctionCache() + else: + return _cpp_pjit_cache_fun_only + + def _pjit_explicit_sharding_and_layout( + in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, + device, backend) -> bool: + return (device is not None or + backend is not None or + any(not is_unspecified(i) for i in in_shardings_flat) or + any(not is_unspecified(o) for o in out_shardings_flat) or + any(i is not None for i in in_layouts_flat) or + any(o is not None for o in out_layouts_flat)) +else: + def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore + if contains_explicit_attributes: + return _cpp_pjit_cache_explicit_attributes + else: + return _cpp_pjit_cache_fun_only def _cpp_pjit(fun: Callable, jit_info: PjitInfo): @@ -339,11 +364,35 @@ def cache_miss(*args, **kwargs): return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler) - cpp_pjit_f = xc._xla.pjit( - fun_name(fun), - fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames, - jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(jit_info.has_explicit_sharding)) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=jit_info.donate_argnums, + donate_argnames=jit_info.donate_argnames, + device=jit_info.device, backend=jit_info.backend, + in_shardings_treedef=jit_info.in_shardings_treedef, + in_shardings_leaves=jit_info.in_shardings_leaves, + out_shardings_treedef=jit_info.out_shardings_treedef, + out_shardings_leaves=jit_info.out_shardings_leaves, + in_layouts_treedef=jit_info.in_layouts_treedef, + in_layouts_leaves=jit_info.in_layouts_leaves, + out_layouts_treedef=jit_info.out_layouts_treedef, + out_layouts_leaves=jit_info.out_layouts_leaves, + use_resource_env=jit_info.use_resource_env) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore + pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes)) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + jit_info.in_shardings_leaves, jit_info.out_shardings_leaves, + jit_info.in_layouts_leaves, jit_info.out_layouts_leaves, + jit_info.device, jit_info.backend) + cpp_pjit_f = xc._xla.pjit( + fun_name(fun), fun, cache_miss, jit_info.static_argnums, + jit_info.static_argnames, jit_info.donate_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding)) cpp_pjitted_f = wraps(fun)(cpp_pjit_f) cpp_pjitted_f._fun = fun @@ -351,17 +400,6 @@ def cache_miss(*args, **kwargs): return cpp_pjitted_f -def _pjit_explicit_sharding_and_layout( - in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat, - device, backend) -> bool: - return (device is not None or - backend is not None or - any(not is_unspecified(i) for i in in_shardings_flat) or - any(not is_unspecified(o) for o in out_shardings_flat) or - any(i is not None for i in in_layouts_flat) or - any(o is not None for o in out_layouts_flat)) - - def _split_layout_and_sharding(entries): entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None) layouts, shardings = [], [] @@ -445,10 +483,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, fun, fun_signature, donate_argnums, donate_argnames, static_argnums, static_argnames) - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings_leaves, out_shardings_leaves, in_layouts_leaves, - out_layouts_leaves, device, backend) - return PjitInfo( fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature, @@ -466,7 +500,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any, donate_argnames=donate_argnames, device=device, backend=backend, keep_unused=keep_unused, inline=inline, abstracted_axes=abstracted_axes, - has_explicit_sharding=has_explicit_sharding, use_resource_env=use_resource_env) @@ -1706,13 +1739,27 @@ def call_impl_cache_miss(*args_, **kwargs_): f = _get_jaxpr_as_fun( jaxpr, in_shardings, out_shardings, in_layouts, out_layouts, resource_env, donated_invars, name, keep_unused, inline) - donated_argnums = [i for i, d in enumerate(donated_invars) if d] - has_explicit_sharding = _pjit_explicit_sharding_and_layout( - in_shardings, out_shardings, in_layouts, out_layouts, None, None) - return xc._xla.pjit( - name, f, call_impl_cache_miss, [], [], donated_argnums, - tree_util.dispatch_registry, pxla.cc_shard_arg, - _get_cpp_global_cache(has_explicit_sharding))(*args) + donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d) + if xla_extension_version >= 286: + cache_key = pxla.JitGlobalCppCacheKeys( + donate_argnums=donated_argnums, donate_argnames=None, + device=None, backend=None, + in_shardings_treedef=None, in_shardings_leaves=in_shardings, + out_shardings_treedef=None, out_shardings_leaves=out_shardings, + in_layouts_treedef=None, in_layouts_leaves=in_layouts, + out_layouts_treedef=None, out_layouts_leaves=out_layouts, + use_resource_env=resource_env is not None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], cache_key, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args) + else: + has_explicit_sharding = _pjit_explicit_sharding_and_layout( + in_shardings, out_shardings, in_layouts, out_layouts, None, None) + return xc._xla.pjit( + name, f, call_impl_cache_miss, [], [], donated_argnums, + tree_util.dispatch_registry, pxla.cc_shard_arg, + _get_cpp_global_cache(has_explicit_sharding))(*args) pjit_p.def_impl(_pjit_call_impl) diff --git a/jax/experimental/multihost_utils.py b/jax/experimental/multihost_utils.py index 554bf2641769..56003ea7af5d 100644 --- a/jax/experimental/multihost_utils.py +++ b/jax/experimental/multihost_utils.py @@ -90,19 +90,17 @@ def sync_global_devices(name: str): assert_equal(h, f"sync_global_devices name mismatch ('{name}')") +# Identity function is at the top level so that `process_allgather` doesn't +# recompile on every invocation. def _identity_fn(x): return x -@lru_cache(maxsize=128) -def _jitted_identity_fn(sharding): - return jax.jit(_identity_fn, out_shardings=sharding) - def _handle_array_process_allgather(inp, tiled): if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable: reps = sharding_impls.GSPMDSharding.get_replicated( inp.sharding._device_assignment) - out = _jitted_identity_fn(reps)(inp) + out = jax.jit(_identity_fn, out_shardings=reps)(inp) else: # All inputs here will be fully addressable. if jax.process_count() == 1: @@ -125,7 +123,8 @@ def _handle_array_process_allgather(inp, tiled): bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()] global_arr = array.make_array_from_single_device_arrays( global_aval.shape, s, bufs) - out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr) + out = jax.jit(_identity_fn, + out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr) return np.asarray(out.addressable_data(0)) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 6c022653581d..11a541f2e5f5 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -57,6 +57,7 @@ from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.util import curry, unzip2 config.parse_flags_with_absl() @@ -652,18 +653,16 @@ def testAutodiff(self, mesh, resources): @jtu.with_mesh([('x', 2), ('y', 1)]) def testAutodiffCache(self): - f = pjit( - lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None - ) + f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None) x = jnp.arange(16, dtype=jnp.float32) - jax.grad(f)(x) # Warm up the cache. - before = pjit_lib._pjit_lower_cached.cache_info() - jax.grad(f)(x) - after = pjit_lib._pjit_lower_cached.cache_info() - # One hit for the forward pass, one hit for backward. - self.assertEqual(after.hits, before.hits + 2) - self.assertEqual(after.misses, before.misses) + jax.grad(f)(x) # Warm up the cache. + with jtu.count_pjit_cpp_cache_miss() as count: + jax.grad(f)(x) + if xla_extension_version >= 286: + self.assertEqual(count[0], 0) # no cache miss i.e. cache hit + else: + self.assertEqual(count[0], 2) @jtu.with_mesh([('x', 2), ('y', 1)]) def testEvalJaxpr(self): @@ -4531,6 +4530,20 @@ def test_wsc_abstract_mesh_errors(self): ' match the mesh shape of the target sharding.*'): with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y'))) + @unittest.skipIf(xla_extension_version < 286, + "Requires xla_extension_version >= 286") + def test_global_jit_cpp_cache_hit_out_shardings(self): + mesh = jtu.create_mesh((2,), 'x') + s = NamedSharding(mesh, P('x')) + + def f(x): + return x * 2 + + with jtu.count_pjit_cpp_cache_miss() as count: + jax.jit(f, out_shardings=s)(np.arange(8)) + jax.jit(f, out_shardings=s)(np.arange(8)) + self.assertEqual(count[0], 1) + def spec_regex(s): return str(s).replace(r"(", r"\(").replace(r")", r"\)")