Skip to content

Commit

Permalink
Integrate Triton up to [657c49e1](https://github.com/openai/triton/commit/657c49e1)
Browse files Browse the repository at this point in the history
  • Loading branch information
The jax_triton Authors committed Feb 16, 2024
1 parent 8f07074 commit 708d3e8
Showing 1 changed file with 0 additions and 18 deletions.
18 changes: 0 additions & 18 deletions jax_triton/triton_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,6 @@ def get_or_create_triton_kernel(
num_stages,
num_ctas,
enable_fp_fusion,
enable_warp_specialization,
enable_persistent,
metaparams,
dump: bool,
) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
Expand Down Expand Up @@ -270,8 +268,6 @@ def get_or_create_triton_kernel(
num_stages,
num_ctas,
enable_fp_fusion,
enable_warp_specialization,
enable_persistent,
)
kernel = _COMPILED_KERNEL_CACHE.get(cache_key)

Expand All @@ -285,8 +281,6 @@ def get_or_create_triton_kernel(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
enable_warp_specialization=enable_warp_specialization,
enable_persistent=enable_persistent,
optimize_epilogue=False,
debug=dump,
enable_fp_fusion=enable_fp_fusion,
Expand Down Expand Up @@ -348,8 +342,6 @@ def triton_kernel_call_lowering(
num_stages,
num_ctas,
enable_fp_fusion,
enable_warp_specialization,
enable_persistent,
input_output_aliases,
zeroed_outputs,
debug,
Expand Down Expand Up @@ -406,9 +398,7 @@ def prune_configs(configs, named_args):
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
enable_warp_specialization=enable_warp_specialization,
)
config.enable_persistent = enable_persistent
configs = [config]

if isinstance(fn, autotuner.Heuristics):
Expand Down Expand Up @@ -449,8 +439,6 @@ def prune_configs(configs, named_args):
num_warps=config.num_warps,
num_stages=config.num_stages,
num_ctas=config.num_ctas,
enable_warp_specialization=config.enable_warp_specialization,
enable_persistent=config.enable_persistent,
grid=config_grid,
zeroed_params_with_sizes=tuple(zeroed_params_with_sizes.items()),
)
Expand All @@ -466,8 +454,6 @@ def prune_configs(configs, named_args):
num_stages=params["num_stages"],
num_ctas=params["num_ctas"],
enable_fp_fusion=enable_fp_fusion,
enable_warp_specialization=params["enable_warp_specialization"],
enable_persistent=params["enable_persistent"],
metaparams=dict(params["metaparams"]),
dump=debug,
)
Expand Down Expand Up @@ -559,8 +545,6 @@ def triton_call(
num_stages: Optional[int] = None,
num_ctas: int = 1, # TODO(giorgioa): Add support for dimensions tuple.
enable_fp_fusion: bool = True,
enable_warp_specialization: bool = False,
enable_persistent: bool = False,
input_output_aliases: Optional[Dict[int, int]] = None,
zeroed_outputs: Union[
Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]
Expand Down Expand Up @@ -687,8 +671,6 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
num_stages=num_stages,
num_ctas=num_ctas,
enable_fp_fusion=enable_fp_fusion,
enable_warp_specialization=enable_warp_specialization,
enable_persistent=enable_persistent,
input_output_aliases=tuple(input_output_aliases.items()),
zeroed_outputs=zeroed_outputs,
debug=debug,
Expand Down

0 comments on commit 708d3e8

Please sign in to comment.