diff --git a/jax_triton/__init__.py b/jax_triton/__init__.py index bfba86a..064313c 100644 --- a/jax_triton/__init__.py +++ b/jax_triton/__init__.py @@ -34,7 +34,16 @@ from jax_triton.version import __version__ from jax_triton.version import __version_info__ -get_compute_capability = gpu_triton.get_compute_capability +try: + get_compute_capability = gpu_triton.get_compute_capability +except AttributeError: + raise ImportError( + "jax-triton requires JAX to be installed with GPU support. The " + "installation page on the JAX documentation website includes " + "instructions for installing a supported version:\n" + "https://jax.readthedocs.io/en/latest/installation.html" + ) + if jaxlib.version.__version_info__ >= (0, 4, 14): try: get_serialized_metadata = gpu_triton.get_serialized_metadata