From 170511d53e1b8ae55a6a824eced55291ca9929c2 Mon Sep 17 00:00:00 2001 From: George Necula Date: Sun, 3 Sep 2023 13:06:50 -0700 Subject: [PATCH] Align the custom_call implementation in mlir and hlo_helpers. PiperOrigin-RevId: 562385810 --- jax_triton/triton_lib.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index d4f2cee0..8e678d8f 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -419,13 +419,13 @@ def prune_configs(configs, named_args): call_proto = kernel_call.to_proto(serialized_metadata) return jaxlib.hlo_helpers.custom_call( call_target_name="triton_kernel_call", - out_types=out_types, + result_types=out_types, operands=array_args, backend_config=zlib.compress(call_proto), operand_layouts=avals_to_layouts(ctx.avals_in), result_layouts=avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), - ) + ).results mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)