From 4af298de84037374415f682c082564f697235d5a Mon Sep 17 00:00:00 2001 From: Christian Sigg Date: Thu, 12 Sep 2024 06:07:19 -0700 Subject: [PATCH] Integrate Triton up to [50d803cd](https://github.com/openai/triton/commits/50d803cdb4e68910ed663251100e168ea4d2519d PiperOrigin-RevId: 673813747 --- jax_triton/triton_lib.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index c20bcaa6..a5b0661f 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -391,7 +391,7 @@ def get_or_create_triton_kernel( if num_ctas > 1 and compute_capability < 90: raise ValueError("num_ctas > 1 unsupported before Hopper.") - signature = dict(enumerate(arg_dtypes)) + signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)} # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers # We assume that all arrays are aligned to 16 bytes, and Triton may use this # assumption, unless array args are include in the `do_not_specialize` list. @@ -403,9 +403,9 @@ def get_or_create_triton_kernel( args_for_specialization_attr[i] = v specialization_attr = fn._get_config(*args_for_specialization_attr) # pylint: disable=protected-access - constants = {fn.arg_names.index(k): v for k, v in metaparams.items()} - constants.update({i: None for i, _, v in scalar_args if v is None}) - constants.update({i: 1 for i in specialization_attr.equal_to_1}) + constants = {k: v for k, v in metaparams.items()} + constants.update({k: None for _, k, v in scalar_args if v is None}) + constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1}) # Cache key should contain any parameter that can affect the compiler output. cache_key = (