Skip to content

Commit

Permalink
[jax_triton] Allow user to override the custom call target name.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 580245167
  • Loading branch information
chr1sj0nes authored and The jax_triton Authors committed Nov 7, 2023
1 parent b81596b commit 68c6262
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ def triton_kernel_call_lowering(
fn,
scalar_args,
name,
custom_call_target_name,
out_shapes,
grid,
num_warps,
Expand Down Expand Up @@ -484,7 +485,7 @@ def prune_configs(configs, named_args):
else:
call_proto = kernel_call.to_proto(serialized_metadata)
return jaxlib.hlo_helpers.custom_call(
call_target_name="triton_kernel_call",
call_target_name=custom_call_target_name,
result_types=out_types,
operands=array_args,
backend_config=zlib.compress(call_proto),
Expand Down Expand Up @@ -514,6 +515,7 @@ def triton_call(
out_shape: Union[ShapeDtype, Sequence[ShapeDtype]],
grid: GridOrLambda,
name: str = "",
custom_call_target_name: str = "triton_kernel_call",
num_warps: Optional[int] = None,
num_stages: Optional[int] = None,
num_ctas: int = 1,
Expand Down Expand Up @@ -636,6 +638,7 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
fn=kernel,
scalar_args=tuple(scalar_args),
name=name,
custom_call_target_name=custom_call_target_name,
out_shapes=tuple(flat_out_shapes),
grid=grid,
num_warps=num_warps,
Expand Down

0 comments on commit 68c6262

Please sign in to comment.