From 8b712edb6dcafe5bca8795fd6de4cd315d302f4e Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 27 Aug 2024 09:23:31 -0700 Subject: [PATCH] Get StableHLO version from compatibility requirements in JAX and PJRT. PiperOrigin-RevId: 668017510 --- jax/_src/export/_export.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/jax/_src/export/_export.py b/jax/_src/export/_export.py index 54defa0e9c54..65f3d2852348 100644 --- a/jax/_src/export/_export.py +++ b/jax/_src/export/_export.py @@ -681,8 +681,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported: def _module_to_bytecode(module: ir.Module) -> bytes: mlir_str = mlir.module_to_bytecode(module) # `target_version` is used to manage situations when a StableHLO producer - # (in this case, jax2tf) and a StableHLO consumer were built using - # different versions of StableHLO. + # and a StableHLO consumer were built using different versions of StableHLO. # # Each StableHLO version `producer_version` has a compatibility window, # i.e. range of versions [`consumer_version_min`, `consumer_version_max`], @@ -691,12 +690,19 @@ def _module_to_bytecode(module: ir.Module) -> bytes: # See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md # for the exact extent of these compatibility guarantees. # - # `hlo.get_minimum_version()` returns `consumer_version_min` - # for the current version of StableHLO. We are using it here to maximize - # forward compatibility, i.e. to maximize how far into the past we can go - # and still have the payloads produced by `serialize_portable_artifact` - # compatible with potential consumers from the past. - target_version = hlo.get_minimum_version() + # `hlo.get_version_from_compatibility_requirement(WEEK_4)` returns a version + # of StableHLO >= 4w old. This allows new StableHLO features to be used after + # ~4w and be compatible with any consumer that is updated on at least a + # monthly cadence. + # + # Note that this does not verify any JAX custom calls, which are only + # guaranteed 3w of forward compatibility, and only prevents use of new + # StableHLO features from failing on older hardware. + if hlo.get_api_version() < 9: + target_version = hlo.get_minimum_version() + else: + target_version = hlo.get_version_from_compatibility_requirement( + hlo.StablehloCompatibilityRequirement.WEEK_4) module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore mlir_str, target_version) return module_serialized