From 589900d51cc1848052c50120659b829e36656b92 Mon Sep 17 00:00:00 2001
From: Bangtian Liu
Date: Thu, 28 Nov 2024 16:57:05 -0600
Subject: [PATCH] [tuner]: use lowering config binding

Signed-off-by: Bangtian Liu
---
 tuner/tuner/candidate_gen.py        |  40 +++++---
 tuner/tuner/candidate_gen_test.py   | 140 ++++++++++++++++++++++------
 tuner/tuner/common.py               |  34 ++++---
 tuner/tuner/common_test.py          |  21 ++++-
 tuner/tuner/dispatch_constraints.py |  36 +++++--
 tuner/tuner/dispatch_parser.py      |   8 +-
 tuner/tuner/dispatch_parser_test.py |  57 ++++++++---
 7 files changed, 255 insertions(+), 81 deletions(-)

diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index f09e08888..01a6ed2aa 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -42,6 +42,9 @@ def apply_configuration(
     template: list[str], configuration: Configuration, tile_sizes: list[int]
 ) -> str:
+    intrinsic = configuration.intrinsic
+    subgroup_m_count = configuration.subgroup_m_count
+    subgroup_n_count = configuration.subgroup_n_count
     tune_logger.info(f"Applying: {configuration}")
     expr0 = re.compile(
         r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
     )
@@ -52,7 +55,7 @@
     expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
     expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
     expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
-    repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
+    repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
     repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
     repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
     repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
@@ -116,6 +119,9 @@ def get_transform_function_mmt(
         self, problem_size: ProblemSize, functionName: str, configuration: Configuration
     ) -> str:
         tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
+        intrinsic = configuration.intrinsic
+        subgroup_m_count = configuration.subgroup_m_count
+        subgroup_n_count = configuration.subgroup_n_count
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -131,8 +137,8 @@
       translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
       workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
       {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = {configuration.intrinsic},
-        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+        intrinsic = {intrinsic},
+        subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
       {extra_config}}}>
       > -> !transform.any_param
     transform.yield %matmul, %config : !transform.any_op, !transform.any_param
@@ -186,6 +192,9 @@ def get_transform_function_conv(
         output = f"tensor<{dynamic_batch_output_ty}>"
 
         tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
+        intrinsic = configuration.intrinsic
+        subgroup_m_count = configuration.subgroup_m_count
+        subgroup_n_count = configuration.subgroup_n_count
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -204,8 +213,8 @@
       translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
       workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
       {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = {configuration.intrinsic},
-        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+        intrinsic = {intrinsic},
+        subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
       {extra_config}}}>
       > -> !transform.any_param
     transform.yield %conv, %config : !transform.any_op, !transform.any_param
@@ -245,6 +254,9 @@ def get_transform_function_broadcast_rhs_mmt(
         configuration: Configuration,
     ) -> str:
         tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
+        intrinsic = configuration.intrinsic
+        subgroup_m_count = configuration.subgroup_m_count
+        subgroup_n_count = configuration.subgroup_n_count
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -265,8 +277,8 @@
       translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
       workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
       {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = {configuration.intrinsic},
-        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+        intrinsic = {intrinsic},
+        subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
       {extra_config}}}>
       > -> !transform.any_param
     transform.yield %generic, %config : !transform.any_op, !transform.any_param
@@ -329,6 +341,9 @@ def get_transform_function_batch_mmt(
         configuration: Configuration,
     ) -> str:
         tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
+        intrinsic = configuration.intrinsic
+        subgroup_m_count = configuration.subgroup_m_count
+        subgroup_n_count = configuration.subgroup_n_count
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -345,8 +360,8 @@
       translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
       workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
       {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = {configuration.intrinsic},
-        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+        intrinsic = {intrinsic},
+        subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
       {extra_config}}}>
       > -> !transform.any_param
     transform.yield %generic, %config : !transform.any_op, !transform.any_param
@@ -395,6 +410,9 @@ def get_transform_function_batch_matmul(
         tile_sizes = ", ".join(
             map(str, get_contract_tile_sizes(configuration, tile_dims))
         )
+        intrinsic = configuration.intrinsic
+        subgroup_m_count = configuration.subgroup_m_count
+        subgroup_n_count = configuration.subgroup_n_count
 
         wg_x, wg_y, wg_z = configuration.workgroup_size
         extra_config = get_pipeline_config(configuration)
@@ -413,8 +431,8 @@
       translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
       workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
       {{mma_schedule = #iree_gpu.mma_schedule<
-        intrinsic = {configuration.intrinsic},
-        subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
+        intrinsic = {intrinsic},
+        subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
       {extra_config}}}>
       > -> !transform.any_param
     transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py
index 19b6e1fe7..98055eb3d 100644
--- a/tuner/tuner/candidate_gen_test.py
+++ b/tuner/tuner/candidate_gen_test.py
@@ -48,13 +48,25 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
 
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
+    }
+
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = common.Configuration(
         subgroup_size=16,
         workgroup_size=[16, 16, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[8, 8, 8],
-        subgroup_m_count=16,
-        subgroup_n_count=16,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
             prefetch_shared_memory=True
         ),
@@ -104,13 +116,24 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 464),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
+    }
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[464, 320, 16],
-        subgroup_m_count=1,
-        subgroup_n_count=4,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
             reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get(
                 iree_gpu.ReorderWorkgroupsStrategy.Transpose
             )
         ),
@@ -171,13 +194,25 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 480),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 384),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 32),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
+    }
+
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[480, 384, 32],
-        subgroup_m_count=1,
-        subgroup_n_count=4,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
@@ -220,13 +255,26 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 416),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+    }
+
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[416, 320, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
@@ -272,13 +320,25 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 64),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+    }
+
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
@@ -322,13 +382,25 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 64),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+    }
+
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=4,
     )
@@ -395,13 +467,25 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 64),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
+    }
+
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = common.Configuration(
         subgroup_size=64,
         workgroup_size=[128, 2, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[128, 64, 128],
-        subgroup_m_count=2,
-        subgroup_n_count=2,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=4,
     )
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index 80c755aa7..ba79f3be6 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -105,26 +105,34 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
     return list(filter(is_comptible, mma_intrinsics))
 
 
-class ReorderWorkgroupsStrategy(Enum):
-    NONE = 0
-    SWIZZLE = 1
-    TRANSPOSE = 2
-
-    def __str__(self) -> str:
-        return self.name.title()
-
-
 @dataclass
 class Configuration:
     subgroup_size: int
     workgroup_size: list[int]
-    intrinsic: iree_gpu.MMAAttr
-    tile_sizes: list[int]
-    subgroup_m_count: int
-    subgroup_n_count: int
+    lowering_config: iree_gpu.LoweringConfigAttr
     gpu_pipeline_options: iree_gpu.PipelineOptionsAttr
     waves_per_eu: int
+    @property
+    def intrinsic(self) -> iree_gpu.MMAAttr:
+        return self.lowering_config.attributes["mma_kind"]
+
+    @property
+    def tilesize_workgroup(self) -> list[int]:
+        return [attr.value for attr in self.lowering_config.attributes["workgroup"]]
+
+    @property
+    def tilesize_reduction(self) -> list[int]:
+        return [attr.value for attr in self.lowering_config.attributes["reduction"]]
+
+    @property
+    def subgroup_m_count(self) -> int:
+        return self.lowering_config.attributes["subgroup_m_count"].value
+
+    @property
+    def subgroup_n_count(self) -> int:
+        return self.lowering_config.attributes["subgroup_n_count"].value
+
 
 
 def get_pipeline_config(configuration: Configuration) -> str:
     extra_config = ""
diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py
index 73d3f04e3..bbb241980 100644
--- a/tuner/tuner/common_test.py
+++ b/tuner/tuner/common_test.py
@@ -73,16 +73,27 @@ def test_gpu_pipeline_options(tuner_ctx: common.TunerContext) -> None:
     )
 
 
-def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
+def test_get_pipeline_config(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
+    }
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = common.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[4, 8, 16],
-        subgroup_m_count=1,
-        subgroup_n_count=1,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
        gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
     )
diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py
index cdfb1bd50..5a775bdcd 100644
--- a/tuner/tuner/dispatch_constraints.py
+++ b/tuner/tuner/dispatch_constraints.py
@@ -10,8 +10,10 @@
 import z3  # type: ignore
 from typing import Iterator
 
+from iree.compiler import ir  # type: ignore
 from iree.compiler.dialects import iree_gpu  # type: ignore
+from iree.compiler.dialects import iree_codegen  # type: ignore
 
 from .common import *
 
@@ -217,15 +219,15 @@ def generate_solutions(
     )
     solver.add(z3.simplify(z3.And(constraints)))
     logger.debug(f"Initial constraints: {solver}")
+
+    int_type = ir.IntegerType.get_signless(32)
+
     i = 0
     while solver.check() == z3.sat:
         model = solver.model()
         lookup = lambda var: model[var].as_long()
-
-        config = Configuration(
-            lookup(subgroup_size),
-            [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
-            getMMAAttr(
+        lowering_config_dict = {
+            "mma_kind": getMMAAttr(
                 problem_size.res_type.element_type,
                 lookup(intrinsic_mn),
                 lookup(intrinsic_mn),
@@ -233,9 +235,27 @@
                 problem_size.lhs_type.element_type,
                 problem_size.rhs_type.element_type,
             ),
-            [lookup(m), lookup(n), lookup(k)],
-            lookup(sg_m_cnt),
-            lookup(sg_n_cnt),
+            "workgroup": ir.ArrayAttr.get(
+                [
+                    ir.IntegerAttr.get(int_type, lookup(m)),
+                    ir.IntegerAttr.get(int_type, lookup(n)),
+                    ir.IntegerAttr.get(int_type, lookup(k)),
+                ]
+            ),
+            "reduction": ir.ArrayAttr.get(
+                []
+            ),  # placeholder for now to be consistent with IREE
+            "subgroup_m_count": ir.IntegerAttr.get(int_type, lookup(sg_m_cnt)),
+            "subgroup_n_count": ir.IntegerAttr.get(int_type, lookup(sg_n_cnt)),
+        }
+
+        lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
+        lowering_config = iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)
+
+        config = Configuration(
+            lookup(subgroup_size),
+            [lookup(wg_x), lookup(wg_y), lookup(wg_z)],
+            lowering_config,
             iree_gpu.PipelineOptionsAttr.get(),
             lookup(waves_per_eu),
         )
diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py
index c4b4b9ad5..421e2c7ef 100644
--- a/tuner/tuner/dispatch_parser.py
+++ b/tuner/tuner/dispatch_parser.py
@@ -21,11 +21,11 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
 
 
 def get_mmt_tile_sizes(configuration: Configuration):
-    return configuration.tile_sizes
+    return configuration.tilesize_workgroup
 
 
 def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> list[int]:
-    m, n, k = configuration.tile_sizes
+    m, n, k = configuration.tilesize_workgroup
     tile_size = [1] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
         if dim == "m":
@@ -38,7 +38,7 @@ def get_contract_tile_sizes(configuration: Configuration, tile_dims: str) -> lis
 
 
 def get_batch_mmt_tile_sizes(configuration: Configuration) -> list[int]:
-    return [1] + configuration.tile_sizes
+    return [1] + configuration.tilesize_workgroup
 
 
 class MlirRegex(Enum):
@@ -141,7 +141,7 @@ def supports(self, op_name: str) -> bool:
         return "conv_2d_nhwc_hwcf" in op_name
 
     def get_conv_tile_sizes(self, configuration: Configuration) -> list[int]:
-        m, n, k = configuration.tile_sizes
+        m, n, k = configuration.tilesize_workgroup
         batch = 1
         fh = 1
         fw = 1
diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py
index 529559f83..8318ca9c1 100644
--- a/tuner/tuner/dispatch_parser_test.py
+++ b/tuner/tuner/dispatch_parser_test.py
@@ -42,13 +42,24 @@ def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None:
 def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 32),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 0),
+    }
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = dispatch_parser.Configuration(
         subgroup_size=0,
         workgroup_size=[],
-        intrinsic=mma_attr,
-        tile_sizes=[128, 320, 32],
-        subgroup_m_count=0,
-        subgroup_n_count=0,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=0,
     )
@@ -58,13 +69,24 @@ def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 464),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
+    }
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = dispatch_parser.Configuration(
         subgroup_size=64,
         workgroup_size=[256, 1, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[464, 320, 16],
-        subgroup_m_count=1,
-        subgroup_n_count=4,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=1,
     )
@@ -82,13 +104,24 @@ def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None:
     mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
     mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
+    lowering_config_dict = {
+        "mma_kind": mma_attr,
+        "workgroup": ir.ArrayAttr.get(
+            [
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
+                ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
+            ]
+        ),
+        "reduction": ir.ArrayAttr.get([]),
+        "subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
+        "subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
+    }
+    lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     config = dispatch_parser.Configuration(
         subgroup_size=32,
         workgroup_size=[16, 16, 1],
-        intrinsic=mma_attr,
-        tile_sizes=[4, 8, 16],
-        subgroup_m_count=1,
-        subgroup_n_count=1,
+        lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
         gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
         waves_per_eu=2,
    )
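
Usage sketch (not part of the commit): the patch collapses the four loose Configuration fields (intrinsic, tile_sizes, subgroup_m_count, subgroup_n_count) into a single iree_gpu.LoweringConfigAttr, and the new properties on common.Configuration unpack that attribute on demand. The round trip below mirrors the test fixtures above; it assumes the iree-compiler Python bindings are installed and that the tuner package is importable as tuner.tuner, per the file layout in this patch.

    from iree.compiler import ir  # type: ignore
    from iree.compiler.dialects import iree_gpu  # type: ignore

    from tuner.tuner import common  # assumed import path for this repo layout

    with ir.Context():
        i32 = ir.IntegerType.get_signless(32)
        mma_attr = iree_gpu.MMAAttr.get(iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16)
        # Pack the tuning parameters into the same dictionary shape the patch uses.
        lowering_config_dict = {
            "mma_kind": mma_attr,
            "workgroup": ir.ArrayAttr.get(
                [ir.IntegerAttr.get(i32, x) for x in (64, 128, 32)]
            ),
            "reduction": ir.ArrayAttr.get([]),  # placeholder, as in the patch
            "subgroup_m_count": ir.IntegerAttr.get(i32, 2),
            "subgroup_n_count": ir.IntegerAttr.get(i32, 2),
        }
        lowering_config = iree_gpu.LoweringConfigAttr.get(
            ir.DictAttr.get(lowering_config_dict)
        )
        config = common.Configuration(
            subgroup_size=64,
            workgroup_size=[128, 2, 1],
            lowering_config=lowering_config,
            gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
            waves_per_eu=2,
        )
        # The properties added in common.py recover the individual values.
        assert config.tilesize_workgroup == [64, 128, 32]
        assert config.tilesize_reduction == []
        assert config.subgroup_m_count == 2
        assert config.subgroup_n_count == 2

Binding the parameters as a real LoweringConfigAttr means the tuner hands the compiler the same attribute the compiler would construct itself, rather than re-serializing loose integers through string templates; the property accessors keep the call sites (get_mmt_tile_sizes, get_contract_tile_sizes, apply_configuration) unchanged apart from the tile_sizes -> tilesize_workgroup rename.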