Skip to content

Commit

Permalink
[Take 2] Generalize global jit cpp cache keys so we can add more keys…
Browse files Browse the repository at this point in the history
… than the current donate_argnums.

This allows us to get more cache hits globally. For example:

Before:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
After:

jax.jit(f, out_shardings=s)(arr)
jax.jit(f, out_shardings=s)(arr)  # cpp cache hit

Reverts b615266

PiperOrigin-RevId: 675746175
  • Loading branch information
pschuh authored and Google-ML-Automation committed Sep 17, 2024
1 parent e92a599 commit 86fe463
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 56 deletions.
6 changes: 4 additions & 2 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2726,7 +2726,8 @@ def clear_backends():
pjit._infer_params_cached.cache_clear()
pjit._pjit_lower_cached.cache_clear()
pjit._create_pjit_jaxpr.cache_clear() # pytype: disable=attribute-error
pjit._cpp_pjit_cache.clear()
pjit._cpp_pjit_cache_fun_only.clear()
pjit._cpp_pjit_cache_explicit_attributes.clear()
xc._xla.PjitFunctionCache.clear_all()

@atexit.register
Expand Down Expand Up @@ -2755,7 +2756,8 @@ def clear_caches():
util.clear_all_weakref_lru_caches()

# Clear all C++ compiled executable caches for pjit
pjit._cpp_pjit_cache.clear()
pjit._cpp_pjit_cache_fun_only.clear()
pjit._cpp_pjit_cache_explicit_attributes.clear()
pjit._infer_params_cached.cache_clear()
xc._xla.PjitFunctionCache.clear_all()

Expand Down
42 changes: 39 additions & 3 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from collections.abc import Callable, Sequence, Iterable, Iterator
import dataclasses
from functools import partial, lru_cache, cached_property
import functools
import itertools as it
import logging
import math
Expand Down Expand Up @@ -61,6 +62,7 @@
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, AutoLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
Expand Down Expand Up @@ -88,6 +90,7 @@ class WeakRefList(list):
logger = logging.getLogger(__name__)

Index = Union[int, slice, tuple[Union[int, slice], ...]]
PyTreeDef = tree_util.PyTreeDef

NoSharding = sharding_specs.NoSharding
Chunked = sharding_specs.Chunked
Expand Down Expand Up @@ -2904,6 +2907,34 @@ class MeshExecutableFastpathData(NamedTuple):
in_device_local_layouts: Sequence[DeviceLocalLayout | None]


@dataclasses.dataclass(frozen=True, kw_only=True)
class JitGlobalCppCacheKeys:
donate_argnums: tuple[int, ...] | None = None
donate_argnames: tuple[str, ...] | None = None
device: xc.Device | None = None
backend: str | None = None
in_shardings_treedef: PyTreeDef | None = None
in_shardings_leaves: tuple[Any, ...] | None = None
out_shardings_treedef: PyTreeDef | None = None
out_shardings_leaves: tuple[Any, ...] | None = None
in_layouts_treedef: PyTreeDef | None = None
in_layouts_leaves: tuple[Any, ...] | None = None
out_layouts_treedef: PyTreeDef | None = None
out_layouts_leaves: tuple[Any, ...] | None = None
use_resource_env: bool = False

@functools.cached_property
def contains_explicit_attributes(self):
return (self.donate_argnums is not None or
self.donate_argnames is not None or
self.device is not None or
self.backend is not None or
any(not is_unspecified(i) for i in self.in_shardings_leaves) or
any(not is_unspecified(o) for o in self.out_shardings_leaves) or
any(i is not None for i in self.in_layouts_leaves) or
any(o is not None for o in self.out_layouts_leaves))


def reflatten_outputs_for_dispatch(out_tree, out_flat):
# We arrive at dispatch having flattened according to the default
# pytree registry, but we want to re-flatten according to our
Expand Down Expand Up @@ -3017,9 +3048,14 @@ def aot_cache_miss(*args, **kwargs):
fastpath_data = None
return outs, fastpath_data, False # Do not remove cache entry

return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)
if xla_extension_version >= 286:
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [],
JitGlobalCppCacheKeys(), tree_util.dispatch_registry, cc_shard_arg)
else:
return xc._xla.pjit(
self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry, cc_shard_arg)

def cc_shard_arg(x, sharding, layout):
return shard_args([sharding], [layout], [x])[0]
Expand Down
117 changes: 82 additions & 35 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
from jax._src.lib.mlir.dialects import func as func_dialect
from jax._src.lib import jax_jit
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src import sharding
from jax._src.mesh import AbstractMesh
from jax._src.sharding_impls import (
Expand Down Expand Up @@ -164,7 +165,6 @@ class PjitInfo(NamedTuple):
keep_unused: bool
inline: bool
abstracted_axes: Any | None
has_explicit_sharding: bool
use_resource_env: bool # False for jit, True for pjit

# Hash and compare PjitInfo by identity when used as a cache key.
Expand Down Expand Up @@ -311,14 +311,39 @@ def _cpp_pjit_evict_fn(self):
# The entries are doubled here from the default 4096 because _pjit_call_impl
# also has a cpp dispatch path and that would double the number of entries in
# the global shared cache.
_cpp_pjit_cache = xc._xla.PjitFunctionCache(capacity=8192)
# This cache is only used for jit's with only fun. For example: jax.jit(f)
_cpp_pjit_cache_fun_only = xc._xla.PjitFunctionCache(capacity=8192)

# This cache is used for jit where extra arguments are defined other than the
# fun. For example: jax.jit(f, donate_argnums=...) OR
# jax.jit(f, out_shardings=...), etc. We don't use the same cache because the
# capacity might get full very fast because of all the jitted function in JAX
# which might evict train_step for example.
_cpp_pjit_cache_explicit_attributes = xc._xla.PjitFunctionCache(capacity=8192)

def _get_cpp_global_cache(pjit_has_explicit_sharding):
if pjit_has_explicit_sharding:
return xc._xla.PjitFunctionCache()
else:
return _cpp_pjit_cache

if xla_extension_version < 286:
def _get_cpp_global_cache(pjit_has_explicit_sharding):
if pjit_has_explicit_sharding:
return xc._xla.PjitFunctionCache()
else:
return _cpp_pjit_cache_fun_only

def _pjit_explicit_sharding_and_layout(
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
device, backend) -> bool:
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(o) for o in out_shardings_flat) or
any(i is not None for i in in_layouts_flat) or
any(o is not None for o in out_layouts_flat))
else:
def _get_cpp_global_cache(contains_explicit_attributes: bool): # type: ignore
if contains_explicit_attributes:
return _cpp_pjit_cache_explicit_attributes
else:
return _cpp_pjit_cache_fun_only


def _cpp_pjit(fun: Callable, jit_info: PjitInfo):
Expand All @@ -339,29 +364,42 @@ def cache_miss(*args, **kwargs):

return outs, maybe_fastpath_data, _need_to_rebuild_with_fdo(pgle_profiler)

cpp_pjit_f = xc._xla.pjit(
fun_name(fun),
fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
jit_info.donate_argnums, tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(jit_info.has_explicit_sharding))
if xla_extension_version >= 286:
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=jit_info.donate_argnums,
donate_argnames=jit_info.donate_argnames,
device=jit_info.device, backend=jit_info.backend,
in_shardings_treedef=jit_info.in_shardings_treedef,
in_shardings_leaves=jit_info.in_shardings_leaves,
out_shardings_treedef=jit_info.out_shardings_treedef,
out_shardings_leaves=jit_info.out_shardings_leaves,
in_layouts_treedef=jit_info.in_layouts_treedef,
in_layouts_leaves=jit_info.in_layouts_leaves,
out_layouts_treedef=jit_info.out_layouts_treedef,
out_layouts_leaves=jit_info.out_layouts_leaves,
use_resource_env=jit_info.use_resource_env)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, cache_key, tree_util.dispatch_registry, # type: ignore
pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))
else:
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
jit_info.in_shardings_leaves, jit_info.out_shardings_leaves,
jit_info.in_layouts_leaves, jit_info.out_layouts_leaves,
jit_info.device, jit_info.backend)
cpp_pjit_f = xc._xla.pjit(
fun_name(fun), fun, cache_miss, jit_info.static_argnums,
jit_info.static_argnames, jit_info.donate_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))

cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
cpp_pjitted_f._fun = fun
type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
return cpp_pjitted_f


def _pjit_explicit_sharding_and_layout(
in_shardings_flat, out_shardings_flat, in_layouts_flat, out_layouts_flat,
device, backend) -> bool:
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(o) for o in out_shardings_flat) or
any(i is not None for i in in_layouts_flat) or
any(o is not None for o in out_layouts_flat))


def _split_layout_and_sharding(entries):
entries_flat, treedef = tree_flatten(entries, is_leaf=lambda x: x is None)
layouts, shardings = [], []
Expand Down Expand Up @@ -445,10 +483,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
fun, fun_signature, donate_argnums, donate_argnames, static_argnums,
static_argnames)

has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings_leaves, out_shardings_leaves, in_layouts_leaves,
out_layouts_leaves, device, backend)

return PjitInfo(
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
Expand All @@ -466,7 +500,6 @@ def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
has_explicit_sharding=has_explicit_sharding,
use_resource_env=use_resource_env)


Expand Down Expand Up @@ -1706,13 +1739,27 @@ def call_impl_cache_miss(*args_, **kwargs_):
f = _get_jaxpr_as_fun(
jaxpr, in_shardings, out_shardings, in_layouts, out_layouts,
resource_env, donated_invars, name, keep_unused, inline)
donated_argnums = [i for i, d in enumerate(donated_invars) if d]
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)
donated_argnums = tuple(i for i, d in enumerate(donated_invars) if d)
if xla_extension_version >= 286:
cache_key = pxla.JitGlobalCppCacheKeys(
donate_argnums=donated_argnums, donate_argnames=None,
device=None, backend=None,
in_shardings_treedef=None, in_shardings_leaves=in_shardings,
out_shardings_treedef=None, out_shardings_leaves=out_shardings,
in_layouts_treedef=None, in_layouts_leaves=in_layouts,
out_layouts_treedef=None, out_layouts_leaves=out_layouts,
use_resource_env=resource_env is not None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], cache_key,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(cache_key.contains_explicit_attributes))(*args)
else:
has_explicit_sharding = _pjit_explicit_sharding_and_layout(
in_shardings, out_shardings, in_layouts, out_layouts, None, None)
return xc._xla.pjit(
name, f, call_impl_cache_miss, [], [], donated_argnums,
tree_util.dispatch_registry, pxla.cc_shard_arg,
_get_cpp_global_cache(has_explicit_sharding))(*args)

pjit_p.def_impl(_pjit_call_impl)

Expand Down
11 changes: 5 additions & 6 deletions jax/experimental/multihost_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,19 +90,17 @@ def sync_global_devices(name: str):
assert_equal(h, f"sync_global_devices name mismatch ('{name}')")


# Identity function is at the top level so that `process_allgather` doesn't
# recompile on every invocation.
def _identity_fn(x):
return x

@lru_cache(maxsize=128)
def _jitted_identity_fn(sharding):
return jax.jit(_identity_fn, out_shardings=sharding)


def _handle_array_process_allgather(inp, tiled):
if isinstance(inp, array.ArrayImpl) and not inp.is_fully_addressable:
reps = sharding_impls.GSPMDSharding.get_replicated(
inp.sharding._device_assignment)
out = _jitted_identity_fn(reps)(inp)
out = jax.jit(_identity_fn, out_shardings=reps)(inp)
else:
# All inputs here will be fully addressable.
if jax.process_count() == 1:
Expand All @@ -125,7 +123,8 @@ def _handle_array_process_allgather(inp, tiled):
bufs = [jax.device_put(host_np_arr, d) for d in jax.local_devices()]
global_arr = array.make_array_from_single_device_arrays(
global_aval.shape, s, bufs)
out = _jitted_identity_fn(jax.NamedSharding(global_mesh, P()))(global_arr)
out = jax.jit(_identity_fn,
out_shardings=jax.NamedSharding(global_mesh, P()))(global_arr)

return np.asarray(out.addressable_data(0))

Expand Down
33 changes: 23 additions & 10 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib import xla_extension_version
from jax._src.util import curry, unzip2

config.parse_flags_with_absl()
Expand Down Expand Up @@ -652,18 +653,16 @@ def testAutodiff(self, mesh, resources):

@jtu.with_mesh([('x', 2), ('y', 1)])
def testAutodiffCache(self):
f = pjit(
lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None
)
f = pjit(lambda x: jnp.sin(x).sum(), in_shardings=P('x'), out_shardings=None)
x = jnp.arange(16, dtype=jnp.float32)
jax.grad(f)(x) # Warm up the cache.
before = pjit_lib._pjit_lower_cached.cache_info()
jax.grad(f)(x)
after = pjit_lib._pjit_lower_cached.cache_info()

# One hit for the forward pass, one hit for backward.
self.assertEqual(after.hits, before.hits + 2)
self.assertEqual(after.misses, before.misses)
jax.grad(f)(x) # Warm up the cache.
with jtu.count_pjit_cpp_cache_miss() as count:
jax.grad(f)(x)
if xla_extension_version >= 286:
self.assertEqual(count[0], 0) # no cache miss i.e. cache hit
else:
self.assertEqual(count[0], 2)

@jtu.with_mesh([('x', 2), ('y', 1)])
def testEvalJaxpr(self):
Expand Down Expand Up @@ -4531,6 +4530,20 @@ def test_wsc_abstract_mesh_errors(self):
' match the mesh shape of the target sharding.*'):
with_sharding_constraint(arr, NamedSharding(abs_mesh2, P('y')))

@unittest.skipIf(xla_extension_version < 286,
"Requires xla_extension_version >= 286")
def test_global_jit_cpp_cache_hit_out_shardings(self):
mesh = jtu.create_mesh((2,), 'x')
s = NamedSharding(mesh, P('x'))

def f(x):
return x * 2

with jtu.count_pjit_cpp_cache_miss() as count:
jax.jit(f, out_shardings=s)(np.arange(8))
jax.jit(f, out_shardings=s)(np.arange(8))
self.assertEqual(count[0], 1)


def spec_regex(s):
return str(s).replace(r"(", r"\(").replace(r")", r"\)")
Expand Down

0 comments on commit 86fe463

Please sign in to comment.