diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 746572d6..fbf784d1 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -303,6 +303,7 @@ def triton_kernel_call_lowering( fn, scalar_args, name, + custom_call_target_name, out_shapes, grid, num_warps, @@ -484,7 +485,7 @@ def prune_configs(configs, named_args): else: call_proto = kernel_call.to_proto(serialized_metadata) return jaxlib.hlo_helpers.custom_call( - call_target_name="triton_kernel_call", + call_target_name=custom_call_target_name, result_types=out_types, operands=array_args, backend_config=zlib.compress(call_proto), @@ -514,6 +515,7 @@ def triton_call( out_shape: Union[ShapeDtype, Sequence[ShapeDtype]], grid: GridOrLambda, name: str = "", + custom_call_target_name: str = "triton_kernel_call", num_warps: Optional[int] = None, num_stages: Optional[int] = None, num_ctas: int = 1, @@ -636,6 +638,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: fn=kernel, scalar_args=tuple(scalar_args), name=name, + custom_call_target_name=custom_call_target_name, out_shapes=tuple(flat_out_shapes), grid=grid, num_warps=num_warps,