Skip to content

Commit

Permalink
Improve the error message when jaxlib is installed without GPU support.
Browse files Browse the repository at this point in the history
Currently jax-triton fails to import with an `AttributeError` when
jaxlib isn't installed with CUDA or ROCM support. This error message
should provide more useful feedback.
  • Loading branch information
dfm committed Sep 16, 2024
1 parent 973e106 commit a5dc021
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@
from jax_triton.version import __version__
from jax_triton.version import __version_info__

get_compute_capability = gpu_triton.get_compute_capability
try:
get_compute_capability = gpu_triton.get_compute_capability
except AttributeError:
raise ImportError(
"jax-triton requires JAX to be installed with GPU support. The "
"installation page on the JAX documentation website includes "
"instructions for installing a supported version:\n"
"https://jax.readthedocs.io/en/latest/installation.html"
)

if jaxlib.version.__version_info__ >= (0, 4, 14):
try:
get_serialized_metadata = gpu_triton.get_serialized_metadata
Expand Down

0 comments on commit a5dc021

Please sign in to comment.