From 4d443b93e2d9c97a12eff4ec303f4a75dafa068a Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 21 Oct 2024 14:17:50 +0100 Subject: [PATCH] Assume that `jaxlib.get_serialized_metadata` is always defined JAX now requires jaxlib >= 0.4.34, so we should as well. --- jax_triton/__init__.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/jax_triton/__init__.py b/jax_triton/__init__.py index 064313c5..977bda50 100644 --- a/jax_triton/__init__.py +++ b/jax_triton/__init__.py @@ -24,7 +24,6 @@ "__version_info__", ] -import jaxlib from jax._src.lib import gpu_triton from jax_triton import utils from jax_triton.triton_lib import triton_call @@ -36,6 +35,7 @@ try: get_compute_capability = gpu_triton.get_compute_capability + get_serialized_metadata = gpu_triton.get_serialized_metadata except AttributeError: raise ImportError( "jax-triton requires JAX to be installed with GPU support. The " @@ -43,13 +43,7 @@ "instructions for installing a supported version:\n" "https://jax.readthedocs.io/en/latest/installation.html" ) - -if jaxlib.version.__version_info__ >= (0, 4, 14): - try: - get_serialized_metadata = gpu_triton.get_serialized_metadata - except AttributeError: - get_serialized_metadata = None +else: + del gpu_triton # Not part of the API. # trailer -del gpu_triton -del jaxlib