From e308f5def7a4f7581a24a682cde6a145b8ac8d16 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 26 Feb 2024 08:09:27 -0800 Subject: [PATCH] Moved Pallas GPU lowering registartion code into a separate submodule This makes the layout similar to the one we use in Pallas TPU. PiperOrigin-RevId: 610411536 --- jax_triton/triton_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index de4ec620..1cfe3789 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -159,7 +159,7 @@ def compile_ttir_to_ptx_inplace( cuda_options: cb.CUDAOptions, device: int = 0, device_type: str = "cuda", -) -> Tuple[str, str, int, int]: +) -> Tuple[str, str, int, int, Any]: compute_capability = triton_kernel_call_lib.get_compute_capability(device) if cuda_options.debug: print(ttir)