diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 38696e6db..f09e08888 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -52,7 +52,7 @@ def apply_configuration(
     expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
     expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
     expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
-    repl0 = f"<intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
+    repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
     repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
     repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
     repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
@@ -119,7 +119,6 @@ def get_transform_function_mmt(
 
     wg_x, wg_y, wg_z = configuration.workgroup_size
     extra_config = get_pipeline_config(configuration)
-
     return f"""
 transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
 %mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
@@ -132,7 +131,7 @@ def get_transform_function_mmt(
     translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
     workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
     {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        intrinsic = {configuration.intrinsic},
         subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
     {extra_config}}}>
     > -> !transform.any_param
@@ -205,7 +204,7 @@ def get_transform_function_conv(
     translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
     workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
     {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        intrinsic = {configuration.intrinsic},
         subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
     {extra_config}}}>
     > -> !transform.any_param
@@ -266,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt(
     translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
     workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
     {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        intrinsic = {configuration.intrinsic},
         subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
     {extra_config}}}>
     > -> !transform.any_param
@@ -346,7 +345,7 @@ def get_transform_function_batch_mmt(
     translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
     workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
     {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        intrinsic = {configuration.intrinsic},
         subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
     {extra_config}}}>
     > -> !transform.any_param
@@ -414,7 +413,7 @@ def get_transform_function_batch_matmul(
     translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
     workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
     {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
+        intrinsic = {configuration.intrinsic},
         subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
     {extra_config}}}>
     > -> !transform.any_param
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 36fb87cbb..d81278e8c 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -13,6 +13,7 @@
 from typing import Generator
 
 from iree.compiler import ir  # type: ignore
+from iree.compiler.dialects import iree_gpu  # type: ignore
 
 from . import candidate_gen
 from . import common
@@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
 
     M, N, K = 2048, 1280, 1280
 
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=16,
         workgroup_size=[16, 16, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[8, 8, 8],
         subgroup_m_count=16,
         subgroup_n_count=16,
@@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
 
     n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640
 
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[464, 320, 16],
         subgroup_m_count=1,
         subgroup_n_count=4,
@@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.contraction,
     )
 
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[480, 384, 32],
         subgroup_m_count=1,
         subgroup_n_count=4,
@@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.batch_matmul,
     )
 
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[416, 320, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
@@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.batch_mmt,
     )
 
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[128, 64, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
@@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.batch_mmt,
     )
 
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+        intrinsic=mma_attr,
         tile_sizes=[128, 64, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
@@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
         common.DispatchKind.broadcast_rhs_mmt,
     )
 
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+        intrinsic=mma_attr,
         tile_sizes=[128, 64, 128],
         subgroup_m_count=2,
         subgroup_n_count=2,
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index b6e31768e..45ae48c22 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -85,74 +85,24 @@ def MNK(self) -> tuple[int, int, int]:
         return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)
 
 
-@dataclass
-class MfmaIntrinsic:
-    output_type: ir.IntegerType | ir.FloatType
-    m: int
-    n: int
-    k: int
-    input_type: ir.IntegerType | ir.FloatType
-
-    def __str__(self) -> str:
-        input = str(self.input_type).upper()
-        output = str(self.output_type).upper()
-        return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}"
-
-    @staticmethod
-    def mfma_f32_16x16x16_f16():
-        f16 = ir.F16Type.get()
-        f32 = ir.F32Type.get()
-        return MfmaIntrinsic(f32, 16, 16, 16, f16)
-
-    @staticmethod
-    def mfma_f32_32x32x8_f16():
-        f16 = ir.F16Type.get()
-        f32 = ir.F32Type.get()
-        return MfmaIntrinsic(f32, 32, 32, 8, f16)
-
-    @staticmethod
-    def mfma_i32_16x16x32_i8():
-        i32 = ir.IntegerType.get_signless(32)
-        i8 = ir.IntegerType.get_signless(8)
-        return MfmaIntrinsic(i32, 16, 16, 32, i8)
-
-    @staticmethod
-    def mfma_i32_32x32x16_i8():
-        i32 = ir.IntegerType.get_signless(32)
-        i8 = ir.IntegerType.get_signless(8)
-        return MfmaIntrinsic(i32, 32, 32, 16, i8)
-
-    @staticmethod
-    def all():
-        return [
-            MfmaIntrinsic.mfma_f32_16x16x16_f16(),
-            MfmaIntrinsic.mfma_f32_32x32x8_f16(),
-            MfmaIntrinsic.mfma_i32_16x16x32_i8(),
-            MfmaIntrinsic.mfma_i32_32x32x16_i8(),
-        ]
-
-
 def get_compatible_mfma_intrinsics(
     problem_size: ProblemSize,
     mma_intrinsics: list[iree_gpu.MMAIntrinsic],
-) -> list[MfmaIntrinsic]:
-    available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]
-
-    def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
-        if problem_size.res_type.element_type != intrinsic.output_type:
+) -> list[iree_gpu.MMAIntrinsic]:
+    def is_compatible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
+        mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
+        a_type, b_type, c_type = mma_attr.abc_element_types
+        if problem_size.res_type.element_type != c_type:
             return False
         if problem_size.dispatch_kind != DispatchKind.batch_matmul:
-            if problem_size.lhs_type.element_type != intrinsic.input_type:
-                return False
-            if problem_size.rhs_type.element_type != intrinsic.input_type:
+            if (
+                problem_size.lhs_type.element_type != a_type
+                or problem_size.rhs_type.element_type != b_type
+            ):
                 return False
-
-        if str(intrinsic) not in available_mma_intrinsics:
-            return False
-
         return True
 
-    return list(filter(is_compatible, MfmaIntrinsic.all()))
+    return list(filter(is_compatible, mma_intrinsics))
 
 
 class ReorderWorkgroupsStrategy(Enum):
@@ -197,7 +147,7 @@ def __str__(self) -> str:
 class Configuration:
     subgroup_size: int
     workgroup_size: list[int]
-    intrinsic: MfmaIntrinsic
+    intrinsic: iree_gpu.MMAAttr
     tile_sizes: list[int]
     subgroup_m_count: int
     subgroup_n_count: int
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index 297ac95a2..ea0a4573d 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -5,7 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 """
-Usage: python -m pytest candidate_gen_test.py
+Usage: python -m pytest common_test.py
 """
 
 import pytest
@@ -72,10 +72,12 @@ def test_gpu_pipeline_options() -> None:
 
 
 def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = common.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[4, 8, 16],
         subgroup_m_count=1,
         subgroup_n_count=1,
@@ -97,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
     )
 
 
-def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None:
-    assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16"
-    assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8"
-
-
 def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
     assert common.get_compatible_mfma_intrinsics(
         common.ProblemSize(
@@ -116,8 +113,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
         ],
     ) == [
-        common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
-        common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+        iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+        iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
    ]
 
     assert common.get_compatible_mfma_intrinsics(
@@ -133,8 +130,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
         ],
     ) == [
-        common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
-        common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
+        iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
+        iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
     ]
 
     assert common.get_compatible_mfma_intrinsics(
@@ -150,8 +147,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
         ],
     ) == [
-        common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
-        common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+        iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
+        iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
     ]
 
     assert common.get_compatible_mfma_intrinsics(
@@ -166,7 +163,7 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
             iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
         ],
     ) == [
-        common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
+        iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
     ]
 
     assert (
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py
index 85039a1e8..f16b4a241 100644
--- a/tuner/tuner/dispatch_constraints.py
+++ b/tuner/tuner/dispatch_constraints.py
@@ -25,10 +25,18 @@ def get_mfma_intrinsic_constraints(
 ) -> z3.BoolRef:
     compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics)
     assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
+
+    mma_attrs = [iree_gpu.MMAAttr.get(mfma) for mfma in compatible_intrinsics]
+    mnk_shapes = [mma_attr.mnk_shape for mma_attr in mma_attrs]
+
     return z3.Or(
         *(
-            z3.And(intrinsic_m == mfma.m, intrinsic_n == mfma.n, intrinsic_k == mfma.k)
-            for mfma in compatible_intrinsics
+            z3.And(
+                intrinsic_m == m,
+                intrinsic_n == n,
+                intrinsic_k == k,
+            )
+            for m, n, k in mnk_shapes
         )
     )
 
@@ -134,6 +142,35 @@ def generate_constraints(
     return constraints
 
 
+def getMMAAttr(
+    output_type: ir.IntegerType | ir.FloatType,
+    m: int,
+    n: int,
+    k: int,
+    lhs_type: ir.IntegerType | ir.FloatType,
+    rhs_type: ir.IntegerType | ir.FloatType,
+) -> iree_gpu.MMAAttr:
+    for mma_intrinsic in iree_gpu.MMAIntrinsic:
+        mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+        a_type, b_type, c_type = mma_attr.abc_element_types
+        mnk = mma_attr.mnk_shape
+        if (
+            a_type == lhs_type
+            and b_type == rhs_type
+            and c_type == output_type
+            and m == mnk[0]
+            and n == mnk[1]
+            and k == mnk[2]
+        ):
+            return mma_attr
+    # If no matching intrinsic is found, raise an exception
+    raise ValueError(
+        f"No matching MMA intrinsic found for "
+        f"output_type={output_type}, lhs_type={lhs_type}, rhs_type={rhs_type}, "
+        f"m={m}, n={n}, k={k}."
+    )
+
+
 def generate_solutions(
     logger: logging.Logger,
     problem_size: ProblemSize,
@@ -188,12 +225,13 @@ def generate_solutions(
         config = Configuration(
             lookup(subgroup_size),
             [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
-            MfmaIntrinsic(
+            getMMAAttr(
                 problem_size.res_type.element_type,
                 lookup(intrinsic_mn),
                 lookup(intrinsic_mn),
                 lookup(intrinsic_k),
                 problem_size.lhs_type.element_type,
+                problem_size.rhs_type.element_type,
             ),
             [lookup(m), lookup(n), lookup(k)],
             lookup(sg_m_cnt),
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index d3a99806f..fb10b04bc 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -5,7 +5,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 """
-Usage: python -m pytest candidate_gen_test.py
+Usage: python -m pytest dispatch_parser_test.py
 """
 
 import pytest
@@ -14,6 +14,7 @@
 
 from iree.compiler import ir  # type: ignore
 from iree.compiler.dialects import func  # type: ignore
+from iree.compiler.dialects import iree_gpu  # type: ignore
 
 from . import common
 from . import dispatch_parser
@@ -39,10 +40,12 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
 
 
 def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=0,
         workgroup_size=[],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[128, 320, 32],
         subgroup_m_count=0,
         subgroup_n_count=0,
@@ -53,10 +56,12 @@ def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
 
 
 def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[464, 320, 16],
         subgroup_m_count=1,
         subgroup_n_count=4,
@@ -75,10 +80,12 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
 
 
 def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
+    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
+    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
     config = dispatch_parser.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
+        intrinsic=mma_attr,
         tile_sizes=[4, 8, 16],
         subgroup_m_count=1,
         subgroup_n_count=1,
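
Note for reviewers: the diff replaces the tuner's hand-rolled MfmaIntrinsic dataclass with the MMAIntrinsic enum and MMAAttr attribute from IREE's Python bindings. As a quick reference, here is a minimal sketch (not part of the diff) of the binding surface the hunks above rely on; it assumes an ir.Context with the IREE dialects registered, which the tuner_ctx/mlir_ctx test fixtures provide:

from iree.compiler import ir  # type: ignore
from iree.compiler.dialects import iree_gpu  # type: ignore

with ir.Context() as ctx:
    # Assumed: IREE dialects are registered on this context, as in the
    # tuner_ctx/mlir_ctx fixtures used by the tests above.
    mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
    mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)

    # These two properties replace MfmaIntrinsic's m/n/k and
    # input_type/output_type fields.
    m, n, k = mma_attr.mnk_shape                         # (16, 16, 16)
    a_type, b_type, c_type = mma_attr.abc_element_types  # (f16, f16, f32)

    # Printing the attribute yields the full MLIR form, e.g.
    # #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>; this is why the
    # candidate_gen.py templates splice {configuration.intrinsic} in directly
    # instead of wrapping it in #iree_gpu.mma_layout<...> themselves.
    print(mma_attr)

getMMAAttr in dispatch_constraints.py inverts this mapping: given the element types and the (m, n, k) shape picked by the z3 solver, it scans iree_gpu.MMAIntrinsic for the first intrinsic whose MMAAttr matches and raises ValueError when none does.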