From 321a19e901ee318a43345485a47e711944011c77 Mon Sep 17 00:00:00 2001 From: Chris Jones Date: Tue, 22 Aug 2023 08:47:17 -0700 Subject: [PATCH] [jax_triton] Always use "triton_kernel_call" as custom call target name. Now that there is a `name` field in the serialized proto, there is an alternative way to differentiate kernel calls. PiperOrigin-RevId: 559124166 --- jax_triton/triton_lib.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index f736a4c0..ceeb7e97 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -30,7 +30,6 @@ import jax.dlpack from jax.interpreters import mlir from jax.interpreters import xla -from jax.lib import xla_client as xc import jax.numpy as jnp import numpy as np @@ -259,7 +258,6 @@ def triton_kernel_call_lowering( fn, scalar_args, name, - call_name, out_shapes, grid, num_warps, @@ -404,7 +402,7 @@ def prune_configs(configs, named_args): for input_idx, output_idx in input_output_aliases ) kernel_call = triton_kernel_call_lib.TritonAutotunedKernelCall( - f"{fn.fn.__name__} ({call_name=}) {named_scalar_args}", + f"{kernel_call_name} ({fn.fn.__name__}) {named_scalar_args}", [(call, str(config)) for call, config in zip(kernel_calls, configs)], input_output_aliases_with_sizes, ) @@ -420,7 +418,7 @@ def prune_configs(configs, named_args): else: call_proto = kernel_call.to_proto(serialized_metadata) return jaxlib.hlo_helpers.custom_call( - call_target_name=call_name, + call_target_name="triton_kernel_call", out_types=out_types, operands=array_args, backend_config=zlib.compress(call_proto), @@ -450,7 +448,6 @@ def triton_call( out_shape: Union[ShapeDtype, Sequence[ShapeDtype]], grid: GridOrLambda, name: str = "", - call_name: str = "triton_kernel_call", # TODO(cjfj): Remove this. 
num_warps: int = 4, num_stages: int = 2, input_output_aliases: Optional[Dict[int, int]] = None, @@ -545,9 +542,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: raise ValueError( "`triton_call` is only available when `triton` is installed." ) - xc.register_custom_call_target( - call_name, triton_kernel_call_lib.get_custom_call(), platform="CUDA" - ) out_shape = tree_util.tree_map( lambda a: jax.ShapeDtypeStruct(a.shape, a.dtype), out_shape) flat_args, _ = tree_util.tree_flatten(args) @@ -572,7 +566,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: fn=kernel, scalar_args=tuple(scalar_args), name=name, - call_name=call_name, out_shapes=tuple(flat_out_shapes), grid=grid, num_warps=num_warps,