Skip to content

Commit

Permalink
Integrate Triton up to [50d803cd](https://github.com/openai/triton/co…
Browse files Browse the repository at this point in the history
  • Loading branch information
chsigg authored and The jax_triton Authors committed Sep 16, 2024
1 parent 949f07b commit 923bc6a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ def get_or_create_triton_kernel(
if num_ctas > 1 and compute_capability < 90:
raise ValueError("num_ctas > 1 unsupported before Hopper.")

signature = dict(enumerate(arg_dtypes))
signature = {fn.arg_names[i]: v for i, v in enumerate(arg_dtypes)}
# TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
# We assume that all arrays are aligned to 16 bytes, and Triton may use this
# assumption, unless array args are include in the `do_not_specialize` list.
Expand All @@ -402,9 +402,9 @@ def get_or_create_triton_kernel(
args_for_specialization_attr[i] = v
specialization_attr = fn._get_config(*args_for_specialization_attr) # pylint: disable=protected-access

constants = {fn.arg_names.index(k): v for k, v in metaparams.items()}
constants.update({i: None for i, _, v in scalar_args if v is None})
constants.update({i: 1 for i in specialization_attr.equal_to_1})
constants = {k: v for k, v in metaparams.items()}
constants.update({k: None for _, k, v in scalar_args if v is None})
constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})

# Cache key should contain any parameter that can affect the compiler output.
cache_key = (
Expand Down

0 comments on commit 923bc6a

Please sign in to comment.