
Commit

Merge pull request #22702 from keshavb96:rm_custom_partitioning_pointer
PiperOrigin-RevId: 673466750
Google-ML-Automation committed Sep 11, 2024
2 parents c708b7c + 7c660c4 commit 859188b
Showing 4 changed files with 83 additions and 2 deletions.
21 changes: 20 additions & 1 deletion jax/_src/cache_key.py
@@ -83,7 +83,8 @@ def get(module: ir.Module,
    'jit__psum-14ac577cdb2ef6d986078b4054cc9893a9a14a16dbb0d8f37b89167c1f1aacdf'
   """
   entries = [
-      ("computation", lambda hash_obj: _hash_computation(hash_obj, module)),
+      ("computation",
+       lambda hash_obj: _hash_computation(hash_obj, module)),
       ("jax_lib version",
        lambda hash_obj: hash_obj.update(
            bytes(jaxlib_version_str.encode("utf-8")))),
@@ -129,8 +130,26 @@ def _log_cache_key_hash(hash_obj, last_serialized: str, hashfn):
   )


+def _remove_custom_partitioning_ptr(m: ir.Module):
+  """
+  Removes custom_partitioning callback pointer from precompiled IR.
+  Python function pointers are not deterministic across executions.
+  """
+  def _update_bc_attribute(op: ir.Operation) -> ir.WalkResult:
+    if (op.name == "stablehlo.custom_call" and
+        op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
+      op.attributes["backend_config"] = ir.StringAttr.get("REMOVED")
+    return ir.WalkResult.ADVANCE
+
+  m.operation.walk(_update_bc_attribute)
+  return m
+
+
 def _serialize_ir(m: ir.Module) -> bytes:
   output = io.BytesIO()
+  if config.remove_custom_partitioning_ptr_from_cache_key.value:
+    m = _remove_custom_partitioning_ptr(type_cast(ir.Module,
+                                                  m.operation.clone()))
   m.operation.write_bytecode(file=output)
   return output.getvalue()

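For context, the pointer this change strips lives in the backend_config attribute of stablehlo.custom_call ops whose call_target_name is CustomSPMDPartitioning. A minimal inspection sketch (not part of this commit; the helper name print_partitioning_backend_configs is made up) showing where the non-deterministic value sits in a lowered module:

# Sketch only: walk a lowered ir.Module and print the backend_config of each
# CustomSPMDPartitioning custom call. Without the new flag, this value embeds
# a process-specific Python callback pointer, so the serialized IR (and the
# cache key derived from it) changes from run to run.
from jax._src.lib.mlir import ir

def print_partitioning_backend_configs(m: ir.Module) -> None:
  def _visit(op: ir.Operation) -> ir.WalkResult:
    if (op.name == "stablehlo.custom_call" and
        op.attributes["call_target_name"].value == "CustomSPMDPartitioning"):
      print(op.attributes["backend_config"])
    return ir.WalkResult.ADVANCE
  m.operation.walk(_visit)
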
4 changes: 3 additions & 1 deletion jax/_src/compilation_cache.py
@@ -265,7 +265,9 @@ def put_executable_and_time(
   cache.put(cache_key, executable_and_time)


-def get_cache_key(module: ir.Module, devices: np.ndarray, compile_options,
+def get_cache_key(module: ir.Module,
+                  devices: np.ndarray,
+                  compile_options,
                   backend) -> str:
   return cache_key.get(module, devices, compile_options, backend,
                        "zstandard" if zstandard is not None else "zlib")
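The get_cache_key wrapper above is only reformatted, not changed in behavior. A rough usage sketch, mirroring the setup the existing cache_key tests use (compiler.get_compile_options and xla_bridge.get_backend are assumed to be available in your JAX version):

import numpy as np
import jax
from jax._src import compilation_cache, compiler, xla_bridge

# Lower a trivial computation and compute its cache key. With the new flag
# enabled, the same call also yields stable keys for computations that use
# custom_partitioning.
computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
devices = np.asarray(jax.devices())
compile_options = compiler.get_compile_options(num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
key = compilation_cache.get_cache_key(
    computation, devices, compile_options, backend)
print(key)
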
10 changes: 10 additions & 0 deletions jax/_src/config.py
@@ -1347,6 +1347,16 @@ def _update_jax_memories_thread_local(val):
           'size to grow indefinitely.'),
 )

+remove_custom_partitioning_ptr_from_cache_key = bool_state(
+    name='jax_remove_custom_partitioning_ptr_from_cache_key',
+    default=False,
+    help=('If set to True, remove the custom partitioning pointer '
+          'present in the precompiled StableHLO before hashing '
+          'during cache key computation. This is a potentially '
+          'unsafe flag to set and only users who are sure of '
+          'what they are trying to achieve should set it.'),
+)
+
 default_dtype_bits = enum_state(
     name='jax_default_dtype_bits',
     enum_values=['32', '64'],
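As with other jax bool_state flags, the new option can be flipped globally or used as a scoped context manager; the latter is what the new test below does. It should also respond to the JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY environment variable, which is standard bool_state behavior rather than something this diff shows. A brief sketch:

import jax
from jax._src import config

# Enable globally by flag name...
jax.config.update("jax_remove_custom_partitioning_ptr_from_cache_key", True)

# ...or only for a scoped block, via the state object itself.
with config.remove_custom_partitioning_ptr_from_cache_key(True):
  pass  # lowering/serialization here hashes IR with the callback pointer removed
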
50 changes: 50 additions & 0 deletions tests/cache_key_test.py
@@ -14,8 +14,10 @@

 import hashlib
 import os
+import re
 import sys
 import unittest
+from typing import cast as type_cast

 import numpy as np

@@ -29,6 +31,11 @@
 from jax._src import test_util as jtu
 from jax._src import xla_bridge
 from jax._src.lib import xla_client
+from jax._src.lib.mlir import ir
+from jax._src.mesh import Mesh
+from jax._src.partition_spec import PartitionSpec as P
+from jax._src.sharding_impls import NamedSharding
+from jax._src.custom_partitioning import custom_partitioning


 config.parse_flags_with_absl()
@@ -155,6 +162,49 @@ def test_different_computations(self):
         cache_key.get(computation2, devices, compile_options, backend),
     )

+  def test_custom_partitioning_ptr_removal(self):
+    def _partition(mesh, arg_shapes, result_shape):
+      arg_shardings = jax.tree.map(lambda x: x.sharding, arg_shapes)
+      result_shardings = NamedSharding(mesh, arg_shapes[0].sharding.spec)
+      return mesh, jax.numpy.add, result_shardings, arg_shardings
+
+    def _infer_sharding_from_operands(mesh, arg_shapes, result_shape):
+      return NamedSharding(mesh, arg_shapes[0].sharding.spec)
+
+    @custom_partitioning
+    def _cp_add(x, y):
+      return jax.numpy.add(x, y)
+
+    _cp_add.def_partition(
+        infer_sharding_from_operands=_infer_sharding_from_operands,
+        partition=_partition)
+
+    devices = np.asarray(jax.devices())
+    with Mesh(devices, ('x',)) as m:
+      computation = jax.jit(
+          _cp_add,
+          in_shardings=(NamedSharding(m, P('x')),
+                        NamedSharding(m, P('x'))),
+          out_shardings=NamedSharding(m, P('x'))
+      ).lower(
+          jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+          jax.ShapeDtypeStruct([1024], dtype=jax.numpy.float32),
+      ).compiler_ir()
+      pattern = (
+          r'stablehlo\.custom_call @CustomSPMDPartitioning\('
+          r'(.*?)\) \{'
+          r'(.*?backend_config\s*=\s*"([^"]*)".*?)'
+          r'\}'
+      )
+      with config.remove_custom_partitioning_ptr_from_cache_key(True):
+        with computation.context:
+          updated_module = cache_key._remove_custom_partitioning_ptr(
+              type_cast(ir.Module, computation.operation.clone()))
+          bcs = [match[2] for
+                 match in re.findall(pattern, str(updated_module), re.DOTALL)]
+          for bc in bcs:
+            self.assertEqual(bc, "REMOVED")
+
   def test_different_device_assignment(self):
     computation = jax.jit(lambda x, y: x + y).lower(1, 1).compiler_ir()
     devices = np.array([[jax.local_devices()[0]]])
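The end-to-end effect this enables, sketched under stated assumptions: the cache directory and threshold value are illustrative, and jax_compilation_cache_dir and jax_persistent_cache_min_compile_time_secs are existing JAX persistent-cache options rather than part of this change.

import jax

# Point JAX at a persistent compilation cache and opt in to pointer removal.
jax.config.update("jax_compilation_cache_dir", "/tmp/jax_cache")
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)
jax.config.update("jax_remove_custom_partitioning_ptr_from_cache_key", True)

# Re-running a program that jits a custom_partitioning computation (such as
# _cp_add in the test above) in a fresh process should now produce the same
# cache key and therefore hit the on-disk cache instead of recompiling.
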
