Skip to content

Commit

Permalink
[jax_triton] Always use "triton_kernel_call" as custom call target name.
Browse files Browse the repository at this point in the history
Now there is a `name` field in the serialized proto, there is an alternative way to differentiate kernel calls.

PiperOrigin-RevId: 559124166
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Aug 22, 2023
1 parent ad44ac8 commit 321a19e
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import jax.dlpack
from jax.interpreters import mlir
from jax.interpreters import xla
from jax.lib import xla_client as xc
import jax.numpy as jnp
import numpy as np

Expand Down Expand Up @@ -259,7 +258,6 @@ def triton_kernel_call_lowering(
fn,
scalar_args,
name,
call_name,
out_shapes,
grid,
num_warps,
Expand Down Expand Up @@ -404,7 +402,7 @@ def prune_configs(configs, named_args):
for input_idx, output_idx in input_output_aliases
)
kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall(
f"{fn.fn.__name__} ({call_name=}) {named_scalar_args}",
f"{kernel_call_name} ({fn.fn.__name__}) {named_scalar_args}",
[(call, str(config)) for call, config in zip(kernel_calls, configs)],
input_output_aliases_with_sizes,
)
Expand All @@ -420,7 +418,7 @@ def prune_configs(configs, named_args):
else:
call_proto = kernel_call.to_proto(serialized_metadata)
return jaxlib.hlo_helpers.custom_call(
call_target_name=call_name,
call_target_name="triton_kernel_call",
out_types=out_types,
operands=array_args,
backend_config=zlib.compress(call_proto),
Expand Down Expand Up @@ -450,7 +448,6 @@ def triton_call(
out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
grid: GridOrLambda,
name: str = "",
call_name: str = "triton_kernel_call", # TODO(cjfj): Remove this.
num_warps: int = 4,
num_stages: int = 2,
input_output_aliases: Optional[Dict[int, int]] = None,
Expand Down Expand Up @@ -545,9 +542,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
raise ValueError(
"`triton_call` is only available when `triton` is installed."
)
xc.register_custom_call_target(
call_name, triton_kernel_call_lib.get_custom_call(), platform="CUDA"
)
out_shape = tree_util.tree_map(
lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape)
flat_args, _ = tree_util.tree_flatten(args)
Expand All @@ -572,7 +566,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
fn=kernel,
scalar_args=tuple(scalar_args),
name=name,
call_name=call_name,
out_shapes=tuple(flat_out_shapes),
grid=grid,
num_warps=num_warps,
Expand Down

0 comments on commit 321a19e

Please sign in to comment.