Skip to content

Commit

Permalink
[triton] Pass cluster_dims to TritonKernel and use cuLaunchKernel if …
Browse files Browse the repository at this point in the history
…size <= 1

PiperOrigin-RevId: 599809560
  • Loading branch information
The jax_triton Authors committed Jan 19, 2024
1 parent b2ae828 commit 4af0ecb
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,6 @@ def get_triton_type(obj: Any) -> str:
)


Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]

triton_kernel_call_p = jax.core.Primitive("triton_kernel_call")
triton_kernel_call_p.multiple_results = True
triton_kernel_call_p.def_impl(
Expand Down Expand Up @@ -240,7 +237,8 @@ def compile_ttir_to_ptx_inplace(
if cuda_options.debug:
print(ptx)
name = ptx_get_kernel_name(ptx)
return ptx, name, shared_mem_bytes, compute_capability
cluster_dims = metadata["cluster_dims"]
return ptx, name, shared_mem_bytes, compute_capability, cluster_dims


_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?
Expand Down Expand Up @@ -312,7 +310,6 @@ def get_or_create_triton_kernel(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
cluster_dims=(1, 1, 1),
enable_warp_specialization=enable_warp_specialization,
enable_persistent=enable_persistent,
optimize_epilogue=False,
Expand All @@ -338,7 +335,7 @@ def get_or_create_triton_kernel(
)

ttir = str(module) # `module`` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
ptx, kernel_name, shared_mem_bytes, compute_capability, cluster_dims = (
compile_ttir_to_ptx_inplace(
module,
context,
Expand All @@ -350,7 +347,13 @@ def get_or_create_triton_kernel(
)

kernel = triton_kernel_call_lib.TritonKernel(
kernel_name, num_warps, shared_mem_bytes, ptx, ttir, compute_capability
kernel_name,
num_warps,
shared_mem_bytes,
ptx,
ttir,
compute_capability,
*cluster_dims,
)

_COMPILED_KERNEL_CACHE[cache_key] = kernel
Expand Down Expand Up @@ -580,7 +583,7 @@ def triton_call(
custom_call_target_name: str = "triton_kernel_call",
num_warps: Optional[int] = None,
num_stages: Optional[int] = None,
num_ctas: int = 1,
num_ctas: int = 1, # TODO(giorgioa): Add support for dimensions tuple.
enable_fp_fusion: bool = True,
enable_warp_specialization: bool = False,
enable_persistent: bool = False,
Expand Down Expand Up @@ -663,6 +666,8 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
indices, for outputs that should be zeroed before the kernel is launched.
num_warps: The number of warps used to execute the Triton kernel.
num_stages: The number of stages emitted by the Triton compiler.
num_ctas: The size of thread blocks per cluster to be used on GPUs with
compute capabilities >= 9.0. It must be less or equal to 8.
debug: Prints out intermediate IRs if True for debugging purposes.
serialized_metadata: Arbitrary metadata that will be added into the
serialized kernel call.
Expand Down

0 comments on commit 4af0ecb

Please sign in to comment.