From b544afd7401815e4547b1939a8f14de0061bc55b Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 25 Nov 2024 12:56:22 -0600 Subject: [PATCH 1/5] [tuner]: remove MfmaIntrinsic and use iree_gpu.MMAAttr instead Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen.py | 13 +++--- tuner/tuner/candidate_gen_test.py | 29 +++++++++--- tuner/tuner/common.py | 72 +++++------------------------ tuner/tuner/common_test.py | 25 +++++----- tuner/tuner/dispatch_constraints.py | 37 +++++++++++++-- tuner/tuner/dispatch_parser_test.py | 15 ++++-- 6 files changed, 95 insertions(+), 96 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 38696e6db..f09e08888 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -52,7 +52,7 @@ def apply_configuration( expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]") expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>") expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"") - repl0 = f"<intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" + repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>" repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},' repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]' repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}" @@ -119,7 +119,6 @@ def get_transform_function_mmt( wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) - return f""" transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{ %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op @@ -132,7 +131,7 @@ def get_transform_function_mmt( translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size}, {{mma_schedule = #iree_gpu.mma_schedule< - intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -205,7 +204,7 @@ def get_transform_function_conv( translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size}, {{mma_schedule = #iree_gpu.mma_schedule< - intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -266,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt( translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size}, {{mma_schedule = #iree_gpu.mma_schedule< - intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -346,7 +345,7 @@ def get_transform_function_batch_mmt( translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size}, {{mma_schedule = #iree_gpu.mma_schedule< - intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param @@ -414,7 +413,7 @@ def get_transform_function_batch_matmul( translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size}, {{mma_schedule = #iree_gpu.mma_schedule< - intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, + intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}> {extra_config}}}> > -> !transform.any_param diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 36fb87cbb..ed744d033 100644 --- a/tuner/tuner/candidate_gen_test.py +++
b/tuner/tuner/candidate_gen_test.py @@ -13,6 +13,7 @@ from typing import Generator from iree.compiler import ir # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore from . import candidate_gen from . import common @@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: M, N, K = 2048, 1280, 1280 + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, f"MFMA_F32_16x16x16_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=16, workgroup_size=[16, 16, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[8, 8, 8], subgroup_m_count=16, subgroup_n_count=16, @@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, @@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.contraction, ) + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=mma_attr, tile_sizes=[480, 384, 32], subgroup_m_count=1, subgroup_n_count=4, @@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_matmul, ) + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + intrinsic=mma_attr, tile_sizes=[416, 320, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], subgroup_m_count=2, subgroup_n_count=2, @@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.broadcast_rhs_mmt, ) + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, workgroup_size=[128, 2, 1], - intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + intrinsic=mma_attr, tile_sizes=[128, 64, 128], 
subgroup_m_count=2, subgroup_n_count=2, diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index b6e31768e..8baa1594e 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -85,74 +85,24 @@ def MNK(self) -> tuple[int, int, int]: return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K) -@dataclass -class MfmaIntrinsic: - output_type: ir.IntegerType | ir.FloatType - m: int - n: int - k: int - input_type: ir.IntegerType | ir.FloatType - - def __str__(self) -> str: - input = str(self.input_type).upper() - output = str(self.output_type).upper() - return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}" - - @staticmethod - def mfma_f32_16x16x16_f16(): - f16 = ir.F16Type.get() - f32 = ir.F32Type.get() - return MfmaIntrinsic(f32, 16, 16, 16, f16) - - @staticmethod - def mfma_f32_32x32x8_f16(): - f16 = ir.F16Type.get() - f32 = ir.F32Type.get() - return MfmaIntrinsic(f32, 32, 32, 8, f16) - - @staticmethod - def mfma_i32_16x16x32_i8(): - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - return MfmaIntrinsic(i32, 16, 16, 32, i8) - - @staticmethod - def mfma_i32_32x32x16_i8(): - i32 = ir.IntegerType.get_signless(32) - i8 = ir.IntegerType.get_signless(8) - return MfmaIntrinsic(i32, 32, 32, 16, i8) - - @staticmethod - def all(): - return [ - MfmaIntrinsic.mfma_f32_16x16x16_f16(), - MfmaIntrinsic.mfma_f32_32x32x8_f16(), - MfmaIntrinsic.mfma_i32_16x16x32_i8(), - MfmaIntrinsic.mfma_i32_32x32x16_i8(), - ] - - def get_compatible_mfma_intrinsics( problem_size: ProblemSize, mma_intrinsics: list[iree_gpu.MMAIntrinsic], -) -> list[MfmaIntrinsic]: - available_mma_intrinsics = [str(mma) for mma in mma_intrinsics] - - def is_compatible(intrinsic: MfmaIntrinsic) -> bool: - if problem_size.res_type.element_type != intrinsic.output_type: +) -> list[iree_gpu.MMAIntrinsic]: + def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + a_type, b_type, c_type = mma_attr.abc_element_types + if problem_size.res_type.element_type != c_type: return False if problem_size.dispatch_kind != DispatchKind.batch_matmul: - if problem_size.lhs_type.element_type != intrinsic.input_type: - return False - if problem_size.rhs_type.element_type != intrinsic.input_type: + if ( + problem_size.lhs_type.element_type != a_type + or problem_size.rhs_type.element_type != b_type + ): return False - - if str(intrinsic) not in available_mma_intrinsics: - return False - return True - return list(filter(is_compatible, MfmaIntrinsic.all())) + return list(filter(is_comptible, mma_intrinsics)) class ReorderWorkgroupsStrategy(Enum): @@ -197,7 +147,7 @@ def __str__(self) -> str: class Configuration: subgroup_size: int workgroup_size: list[int] - intrinsic: MfmaIntrinsic + intrinsic: iree_gpu.MMAAttr tile_sizes: list[int] subgroup_m_count: int subgroup_n_count: int diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 297ac95a2..fc44d80ff 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """ -Usage: python -m pytest candidate_gen_test.py +Usage: python -m pytest common_test.py """ import pytest @@ -72,10 +72,12 @@ def test_gpu_pipeline_options() -> None: def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - 
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, @@ -97,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: ) -def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None: - assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" - assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" - - def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( @@ -116,8 +113,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ], ) == [ - common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] assert common.get_compatible_mfma_intrinsics( @@ -133,8 +130,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ], ) == [ - common.MfmaIntrinsic.mfma_i32_16x16x32_i8(), - common.MfmaIntrinsic.mfma_i32_32x32x16_i8(), + iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8, + iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8, ] assert common.get_compatible_mfma_intrinsics( @@ -150,8 +147,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ], ) == [ - common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16, + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] assert common.get_compatible_mfma_intrinsics( @@ -166,7 +163,7 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ], ) == [ - common.MfmaIntrinsic.mfma_f32_32x32x8_f16(), + iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16, ] assert ( diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 85039a1e8..13bf1a5ce 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -27,8 +27,14 @@ def get_mfma_intrinsic_constraints( assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" return z3.Or( *( - z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k) - for mfma in compatible_intrinsics + z3.And( + intrinsic_m == mma_attr.mnk_shape[0], + intrinsic_n == mma_attr.mnk_shape[1], + intrinsic_k == mma_attr.mnk_shape[2], + ) + for mma_attr in ( + iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics + ) ) ) @@ -134,6 +140,30 @@ def generate_constraints( return constraints +def getMMAAttr( + output_type: ir.IntegerType | ir.FloatType, + m: int, + n: int, + k: int, + lhs_type: ir.IntegerType | ir.FloatType, + rhs_type: ir.IntegerType | ir.FloatType, +) -> iree_gpu.MMAAttr: + mma_str = "" + if lhs_type == rhs_type: + input = str(lhs_type).upper() + output = str(output_type).upper() + mma_str = f"MFMA_{output}_{m}x{n}x{k}_{input}" + else: + lhs = str(lhs_type).upper() + rhs = str(rhs_type).upper() + output = str(output_type).upper() + mma_str = f"MFMA_{output}_{m}x{n}x{k}_{lhs}_{rhs}" + + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, mma_str) + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + return mma_attr + + def generate_solutions( logger: logging.Logger, problem_size: ProblemSize, @@ -188,12 +218,13 @@ def 
generate_solutions( config = Configuration( lookup(subgroup_size), [lookup(wg_x), lookup(wg_y), lookup(wg_z)], - MfmaIntrinsic( + getMMAAttr( problem_size.res_type.element_type, lookup(intrinsic_mn), lookup(intrinsic_mn), lookup(intrinsic_k), problem_size.lhs_type.element_type, + problem_size.rhs_type.element_type, ), [lookup(m), lookup(n), lookup(k)], lookup(sg_m_cnt), diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index d3a99806f..cca9f3606 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -5,7 +5,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception """ -Usage: python -m pytest candidate_gen_test.py +Usage: python -m pytest dispatch_parser_test.py """ import pytest @@ -14,6 +14,7 @@ from iree.compiler import ir # type: ignore from iree.compiler.dialects import func # type: ignore +from iree.compiler.dialects import iree_gpu # type: ignore from . import common from . import dispatch_parser @@ -39,10 +40,12 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=0, workgroup_size=[], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[128, 320, 32], subgroup_m_count=0, subgroup_n_count=0, @@ -53,10 +56,12 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[464, 320, 16], subgroup_m_count=1, subgroup_n_count=4, @@ -75,10 +80,12 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: + mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], - intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(), + intrinsic=mma_attr, tile_sizes=[4, 8, 16], subgroup_m_count=1, subgroup_n_count=1, From 03238b2dddddbaadde21317f6e452ef17e0329bf Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 25 Nov 2024 15:34:29 -0600 Subject: [PATCH 2/5] [tuner]: address comments Signed-off-by: Bangtian Liu --- tuner/tuner/candidate_gen_test.py | 14 +++++----- tuner/tuner/common.py | 2 +- tuner/tuner/common_test.py | 2 +- tuner/tuner/dispatch_constraints.py | 40 +++++++++++++++++------------ tuner/tuner/dispatch_parser_test.py | 6 ++--- 5 files changed, 35 insertions(+), 29 deletions(-) diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index ed744d033..d81278e8c 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -46,7 +46,7 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: M, N, K = 2048, 1280, 1280 - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, f"MFMA_F32_16x16x16_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration(
subgroup_size=16, @@ -100,7 +100,7 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640 - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, @@ -166,7 +166,7 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.contraction, ) - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, @@ -215,7 +215,7 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_matmul, ) - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_32x32x8_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, @@ -267,7 +267,7 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, @@ -317,7 +317,7 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.batch_mmt, ) - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, @@ -390,7 +390,7 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: common.DispatchKind.broadcast_rhs_mmt, ) - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_I32_32x32x16_I8") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=64, diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 8baa1594e..45ae48c22 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -90,7 +90,7 @@ def get_compatible_mfma_intrinsics( mma_intrinsics: list[iree_gpu.MMAIntrinsic], ) -> list[iree_gpu.MMAIntrinsic]: def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool: - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) + mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma a_type, b_type, c_type = mma_attr.abc_element_types if problem_size.res_type.element_type != c_type: return False diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index fc44d80ff..ea0a4573d 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -72,7 +72,7 @@ def test_gpu_pipeline_options() -> None: def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = common.Configuration( subgroup_size=32, diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 13bf1a5ce..3e8dc6c9f 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -28,13 +28,14 @@ def 
get_mfma_intrinsic_constraints( return z3.Or( *( z3.And( - intrinsic_m == mma_attr.mnk_shape[0], - intrinsic_n == mma_attr.mnk_shape[1], - intrinsic_k == mma_attr.mnk_shape[2], + intrinsic_m == mnk[0], + intrinsic_n == mnk[1], + intrinsic_k == mnk[2], ) for mma_attr in ( iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics ) + for mnk in [mma_attr.mnk_shape] ) ) @@ -148,20 +149,25 @@ def getMMAAttr( lhs_type: ir.IntegerType | ir.FloatType, rhs_type: ir.IntegerType | ir.FloatType, ) -> iree_gpu.MMAAttr: - mma_str = "" - if lhs_type == rhs_type: - input = str(lhs_type).upper() - output = str(output_type).upper() - mma_str = f"MFMA_{output}_{m}x{n}x{k}_{input}" - else: - lhs = str(lhs_type).upper() - rhs = str(rhs_type).upper() - output = str(output_type).upper() - mma_str = f"MFMA_{output}_{m}x{n}x{k}_{lhs}_{rhs}" - - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, mma_str) - mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) - return mma_attr + for mma_intrinsic in iree_gpu.MMAIntrinsic: + mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma + a_type, b_type, c_type = mma_attr.abc_element_types + mnk = mma_attr.mnk_shape + if ( + a_type == lhs_type + and b_type == rhs_type + and c_type == output_type + and m == mnk[0] + and n == mnk[1] + and k == mnk[2] + ): + return mma_attr + # If no matching intrinsic is found, raise an exception + raise ValueError( + f"No matching MMA intrinsic found for " + f"output_type={output_type}, lhs_type={lhs_type}, rhs_type={rhs_type}, " + f"m={m}, n={n}, k={k}." + ) def generate_solutions( diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index cca9f3606..fb10b04bc 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -40,7 +40,7 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=0, @@ -56,7 +56,7 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=64, @@ -80,7 +80,7 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: - mma_intrinsic = getattr(iree_gpu.MMAIntrinsic, "MFMA_F32_16x16x16_F16") + mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16 mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) config = dispatch_parser.Configuration( subgroup_size=32, From 385f484291a908156a92577929bbffbb3a5ef352 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Mon, 25 Nov 2024 21:10:13 -0600 Subject: [PATCH 3/5] [tuner] format code again Signed-off-by: Bangtian Liu --- tuner/tuner/dispatch_constraints.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 3e8dc6c9f..93bfa0ff3 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -25,6 +25,10 @@ def get_mfma_intrinsic_constraints( ) -> z3.BoolRef: 
compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics) assert len(compatible_intrinsics) > 0, "No compatible intrinsics found" + + mma_attrs = [iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics] + mnk_shapes = [mma_attr.mnk_shape for mma_attr in mma_attrs] + return z3.Or( *( z3.And( @@ -35,7 +39,7 @@ def get_mfma_intrinsic_constraints( for mma_attr in ( iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics ) - for mnk in [mma_attr.mnk_shape] + for mnk in mnk_shapes ) ) @@ -150,7 +154,7 @@ def getMMAAttr( rhs_type: ir.IntegerType | ir.FloatType, ) -> iree_gpu.MMAAttr: for mma_intrinsic in iree_gpu.MMAIntrinsic: - mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma + mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic) a_type, b_type, c_type = mma_attr.abc_element_types mnk = mma_attr.mnk_shape if ( From 579af35a5d90579e0920d4100db7a8c3bfc08b56 Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Tue, 26 Nov 2024 08:27:14 -0600 Subject: [PATCH 4/5] [tuner] save the code Signed-off-by: Bangtian Liu --- tuner/tuner/dispatch_constraints.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 93bfa0ff3..2c7f73c85 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -36,9 +36,6 @@ def get_mfma_intrinsic_constraints( intrinsic_n == mnk[1], intrinsic_k == mnk[2], ) - for mma_attr in ( - iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics - ) for mnk in mnk_shapes ) ) From fe1af17e83ba77e5471266a423ef048350f3621d Mon Sep 17 00:00:00 2001 From: Bangtian Liu Date: Tue, 26 Nov 2024 09:16:35 -0600 Subject: [PATCH 5/5] [tuner] simplify for loop Signed-off-by: Bangtian Liu --- tuner/tuner/dispatch_constraints.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index 2c7f73c85..f16b4a241 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -32,11 +32,11 @@ def get_mfma_intrinsic_constraints( return z3.Or( *( z3.And( - intrinsic_m == mnk[0], - intrinsic_n == mnk[1], - intrinsic_k == mnk[2], + intrinsic_m == m, + intrinsic_n == n, + intrinsic_k == k, ) - for mnk in mnk_shapes + for m, n, k in mnk_shapes ) )
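
For readers tracking the API change driving this series: the hand-rolled MfmaIntrinsic dataclass is replaced by attributes from the iree_gpu dialect bindings, and all shape and element-type queries move onto the attribute. A minimal sketch of the two accessors the new code relies on, assuming only the iree.compiler bindings already imported in the diffs above (the bare ir.Context() setup mirrors the test fixtures and is illustrative):

from iree.compiler import ir  # type: ignore
from iree.compiler.dialects import iree_gpu  # type: ignore

with ir.Context():
    # Enum value -> attribute; this pair of calls replaces the removed
    # MfmaIntrinsic.mfma_f32_16x16x16_f16() style constructors.
    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)

    # Shape and element types now come from the attribute itself.
    m, n, k = mma_attr.mnk_shape
    a_type, b_type, c_type = mma_attr.abc_element_types
    print(m, n, k, a_type, b_type, c_type)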
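
After patches 3 through 5, get_mfma_intrinsic_constraints reduces to a single disjunction over the compatible intrinsic shapes. A standalone sketch of the resulting z3 constraint, with hard-coded shapes standing in for the mnk_shapes list built from iree_gpu.MMAAttr.mnk_shape:

import z3

# Example shapes standing in for [mma_attr.mnk_shape for mma_attr in mma_attrs].
mnk_shapes = [(16, 16, 16), (32, 32, 8)]

intrinsic_m, intrinsic_n, intrinsic_k = z3.Ints("intrinsic_m intrinsic_n intrinsic_k")

# One conjunct per intrinsic shape; the solver must commit to exactly one shape.
constraint = z3.Or(
    *(
        z3.And(intrinsic_m == m, intrinsic_n == n, intrinsic_k == k)
        for m, n, k in mnk_shapes
    )
)

solver = z3.Solver()
solver.add(constraint, intrinsic_m == 32)
assert solver.check() == z3.sat  # forces the 32x32x8 shape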
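
Finally, the getMMAAttr helper introduced in patch 2 (and switched to MMAAttr.get in patch 3) is the reverse lookup: it maps a solver-chosen (m, n, k) plus the problem's element types back to an MMA attribute by scanning every MMAIntrinsic, raising ValueError when nothing matches. A usage sketch; the import path is a guess from the diff's file layout, and the context handling again mirrors the test fixtures:

from iree.compiler import ir  # type: ignore

from tuner import dispatch_constraints  # hypothetical import path

with ir.Context():
    f16 = ir.F16Type.get()
    f32 = ir.F32Type.get()
    # A solved 16x16x16 shape with f16 inputs and an f32 accumulator should
    # resolve to the MFMA_F32_16x16x16_F16 intrinsic.
    mma_attr = dispatch_constraints.getMMAAttr(f32, 16, 16, 16, f16, f16)
    print(mma_attr)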