diff --git a/jax/_src/cache_key.py b/jax/_src/cache_key.py
index 6fdf0c600b7d..9bce9d0e4308 100644
--- a/jax/_src/cache_key.py
+++ b/jax/_src/cache_key.py
@@ -83,7 +83,8 @@ def get(module: ir.Module,
   'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
   """
   entries = [
-      ("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
+      ("computation",
+       lambda hash_obj: _hash_computation(hash_obj, module)),
       ("jax_lib version",
        lambda hash_obj: hash_obj.update(
            bytes(jaxlib_version_str.encode("utf-8")))),
@@ -129,8 +130,26 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
     )
 
 
+def _remove_custom_partitioning_ptr(m: ir.Module):
+  """
+  Removes custom_partitioning callback pointer from precompiled IR.
+  Python function pointers are not deterministic across executions.
+  """
+  def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
+    if (op.name == "stablehlo.custom_call" and
+        op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
+      op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
+    return ir.WalkResult.ADVANCE
+
+  m.operation.walk(_update_bc_attribute)
+  return m
+
+
 def _serialize_ir(m: ir.Module) -> bytes:
   output = io.BytesIO()
+  if config.remove_custom_partitioning_ptr_from_cache_key.value:
+    m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
+                                                  m.operation.clone()))
   m.operation.write_bytecode(file=output)
   return output.getvalue()
 
diff --git a/jax/_src/compilation_cache.py b/jax/_src/compilation_cache.py
index 8117f871a969..b946dc0a2897 100644
--- a/jax/_src/compilation_cache.py
+++ b/jax/_src/compilation_cache.py
@@ -265,7 +265,9 @@ def put_executable_and_time(
   cache.put(cache_key, executable_and_time)
 
 
-def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
+def get_cache_key(module: ir.Module,
+                  devices: np.ndarray,
+                  compile_options,
                   backend) -> str:
   return cache_key.get(module, devices, compile_options, backend,
                        "zstandard" if zstandard is not None else "zlib")
diff --git a/jax/_src/config.py b/jax/_src/config.py
index b2d1aa52ef2a..51d7dab585af 100644
--- a/jax/_src/config.py
+++ b/jax/_src/config.py
@@ -1347,6 +1347,16 @@ def _update_jax_memories_thread_local(val):
           'size to grow indefinitely.'),
 )
 
+remove_custom_partitioning_ptr_from_cache_key = bool_state(
+    name='jax_remove_custom_partitioning_ptr_from_cache_key',
+    default=False,
+    help=('If set to True, remove the custom partitioning pointer '
+          'present in the precompiled stableHLO before hashing '
+          'during cache key computation. This is a potentially '
+          'unsafe flag to set and only users who are sure of '
+          'what they are trying to achieve should set it.'),
+)
+
 default_dtype_bits = enum_state(
     name='jax_default_dtype_bits',
     enum_values=['32', '64'],
diff --git a/tests/cache_key_test.py b/tests/cache_key_test.py
index 508dbacc2a98..00925c5f7dfc 100644
--- a/tests/cache_key_test.py
+++ b/tests/cache_key_test.py
@@ -14,8 +14,10 @@
 
 import hashlib
 import os
+import re
 import sys
 import unittest
+from typing import cast as type_cast
 
 import numpy as np
 
@@ -29,6 +31,11 @@
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib.mlir import ir
+from jax._src.mesh import Mesh
+from jax._src.partition_spec import PartitionSpec as P
+from jax._src.sharding_impls import NamedSharding
+from jax._src.custom_partitioning import custom_partitioning
 
 
 config.parse_flags_with_absl()
@@ -155,6 +162,49 @@ def test_different_computations(self):
         cache_key.get(computation2, devices, compile_options, backend),
     )
 
+  def test_custom_partitioning_ptr_removal(self):
+    def _partition(mesh, arg_shapes, result_shape):
+      arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
+      result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec)
+      return mesh, jax.numpy.add, result_shardings, arg_shardings
+
+    def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      return NamedSharding(mesh, arg_shapes[0].sharding.spec)
+
+    @custom_partitioning
+    def _cp_add(x, y):
+      return jax.numpy.add(x, y)
+
+    _cp_add.def_partition(
+        infer_sharding_from_operands=_infer_sharding_from_operands,
+        partition=_partition)
+
+    devices = np.asarray(jax.devices())
+    with Mesh(devices, ('x',)) as m:
+      computation = jax.jit(
+          _cp_add,
+          in_shardings=(NamedSharding(m, P('x')),
+                        NamedSharding(m, P('x'))),
+          out_shardings=NamedSharding(m, P('x'))
+      ).lower(
+          jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+          jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+      ).compiler_ir()
+      pattern = (
+          r'stablehlo\.custom_call @CustomSPMDPartitioning\('
+          r'(.*?)\) \{'
+          r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
+          r'\}'
+      )
+      with config.remove_custom_partitioning_ptr_from_cache_key(True):
+        with computation.context:
+          updated_module = cache_key._remove_custom_partitioning_ptr(
+              type_cast(ir.Module, computation.operation.clone()))
+        bcs = [match[2] for
+               match in re.findall(pattern, str(updated_module), re.DOTALL)]
+        for bc in bcs:
+          self.assertEqual(bc, "REMOVED")
+
   def test_different_device_assignment(self):
     computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
     devices = np.array([[jax.local_devices()[0]]])