From a5dc021da012f201f310f945d6a9ef94ce764e9e Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Mon, 16 Sep 2024 12:58:21 -0400 Subject: [PATCH] Improve the error message when jaxlib is installed without GPU support. Currently jax-triton fails to import with an `AttributeError` when jaxlib isn't installed with CUDA or ROCM support. This error message should provide more useful feedback. --- jax_triton/__init__.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/jax_triton/__init__.py b/jax_triton/__init__.py index bfba86a0..064313c5 100644 --- a/jax_triton/__init__.py +++ b/jax_triton/__init__.py @@ -34,7 +34,16 @@ from jax_triton.version import __version__ from jax_triton.version import __version_info__ -get_compute_capability = gpu_triton.get_compute_capability +try: + get_compute_capability = gpu_triton.get_compute_capability +except AttributeError: + raise ImportError( + "jax-triton requires JAX to be installed with GPU support. The " + "installation page on the JAX documentation website includes " + "instructions for installing a supported version:\n" + "https://jax.readthedocs.io/en/latest/installation.html" + ) + if jaxlib.version.__version_info__ >= (0, 4, 14): try: get_serialized_metadata = gpu_triton.get_serialized_metadata