Skip to content

Commit

Permalink
Get StableHLO version from compatibility requirements in JAX and PJRT.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 668017510
  • Loading branch information
GleasonK authored and jax authors committed Aug 28, 2024
1 parent f0a7266 commit f29c905
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions jax/_src/export/_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,8 +681,7 @@ def _get_exported_vjp(exp_primal: Exported) -> Exported:
def _module_to_bytecode(module: ir.Module) -> bytes:
mlir_str = mlir.module_to_bytecode(module)
# `target_version` is used to manage situations when a StableHLO producer
# (in this case, jax2tf) and a StableHLO consumer were built using
# different versions of StableHLO.
# and a StableHLO consumer were built using different versions of StableHLO.
#
# Each StableHLO version `producer_version` has a compatibility window,
# i.e. range of versions [`consumer_version_min`, `consumer_version_max`],
Expand All @@ -691,12 +690,19 @@ def _module_to_bytecode(module: ir.Module) -> bytes:
# See https://github.com/openxla/stablehlo/blob/main/docs/compatibility.md
# for the exact extent of these compatibility guarantees.
#
# `hlo.get_minimum_version()` returns `consumer_version_min`
# for the current version of StableHLO. We are using it here to maximize
# forward compatibility, i.e. to maximize how far into the past we can go
# and still have the payloads produced by `serialize_portable_artifact`
# compatible with potential consumers from the past.
target_version = hlo.get_minimum_version()
# `hlo.get_version_from_compatibility_requirement(WEEK_4)` returns a version
# of StableHLO >= 4w old. This allows new StableHLO features to be used after
# ~4w and be compatible with any consumer that is updated on at least a
# monthly cadence.
#
# Note that this does not verify any JAX custom calls, which are only
# guaranteed 3w of forward compatibility, and only prevents use of new
# StableHLO features from failing on older hardware.
if hlo.get_api_version() < 9:
target_version = hlo.get_minimum_version()
else:
target_version = hlo.get_version_from_compatibility_requirement(
hlo.StablehloCompatibilityRequirement.WEEK_4)
module_serialized = xla_client._xla.mlir.serialize_portable_artifact( # type: ignore
mlir_str, target_version)
return module_serialized
Expand Down

0 comments on commit f29c905

Please sign in to comment.