Skip to content

Commit

Permalink
Merge pull request #305 from superbobry:maint
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 688152751
  • Loading branch information
Google-ML-Automation committed Oct 21, 2024
2 parents cd49678 + 4d443b9 commit 354c15a
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions jax_triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
"__version_info__",
]

import jaxlib
from jax._src.lib import gpu_triton
from jax_triton import utils
from jax_triton.triton_lib import triton_call
Expand All @@ -36,20 +35,15 @@

try:
get_compute_capability = gpu_triton.get_compute_capability
get_serialized_metadata = gpu_triton.get_serialized_metadata
except AttributeError:
raise ImportError(
"jax-triton requires JAX to be installed with GPU support. The "
"installation page on the JAX documentation website includes "
"instructions for installing a supported version:\n"
"https://jax.readthedocs.io/en/latest/installation.html"
)

if jaxlib.version.__version_info__ >= (0, 4, 14):
try:
get_serialized_metadata = gpu_triton.get_serialized_metadata
except AttributeError:
get_serialized_metadata = None
else:
del gpu_triton # Not part of the API.

# trailer
del gpu_triton
del jaxlib

0 comments on commit 354c15a

Please sign in to comment.