diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py
index 1c6ef5c8d..2a544ef55 100644
--- a/tuner/tuner/candidate_gen.py
+++ b/tuner/tuner/candidate_gen.py
@@ -61,8 +61,8 @@ def apply_configuration(
     expr5 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
     repl0 = f""
     repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
-    repl2 = f'workgroup = [{", ".join(map(str, workgroup_sizes))}]'
-    repl3 = f'reduction = [{", ".join(map(str, reduction_sizes))}]'
+    repl2 = f"workgroup = {workgroup_sizes}"
+    repl3 = f"reduction = {reduction_sizes}"
     repl4 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
     repl5 = f'"amdgpu-waves-per-eu" = "{configuration.waves_per_eu}"'
 
diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py
index a839ad4c4..702008f5e 100644
--- a/tuner/tuner/common.py
+++ b/tuner/tuner/common.py
@@ -125,14 +125,14 @@ def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]:
     return None
 
 
-def get_tilesize_workgroup(config: Configuration) -> list[int]:
+def get_workgroup_tile_sizes(config: Configuration) -> list[int]:
     if "workgroup" in config.lowering_config.attributes:
         workgroup_attrs = config.lowering_config.attributes["workgroup"]
         return [attr.value for attr in workgroup_attrs]
     return []
 
 
-def get_tilesize_reduction(config: Configuration) -> list[int]:
+def get_reduction_tile_sizes(config: Configuration) -> list[int]:
     if "reduction" in config.lowering_config.attributes:
         reduction_attrs = config.lowering_config.attributes["reduction"]
         return [attr.value for attr in reduction_attrs]
@@ -163,26 +163,29 @@ def get_lowering_config(
         promoted_value = value
         match key:
             case "workgroup" | "reduction":
-                assert isinstance(
-                    value, (list, ir.ArrayAttr)
-                ), f"Unsupported type for key '{key}': {type(value).__name__}"
                 if isinstance(value, list):
                     promoted_value = ir.ArrayAttr.get(
                         [tuner_ctx.type.getI64(x) for x in value]
                     )
+                elif not isinstance(value, ir.ArrayAttr):
+                    assert (
+                        False
+                    ), f"Unsupported type for key '{key}': {type(value).__name__}"
             case "subgroup_m_count" | "subgroup_n_count":
-                assert isinstance(
-                    value, (int, tuner_ctx.type.i64)
-                ), f"Unsupported type for key '{key}': {type(value).__name__}"
                 if isinstance(value, int):
                     promoted_value = tuner_ctx.type.getI64(value)
+                elif not isinstance(value, tuner_ctx.type.i64):
+                    assert (
+                        False
+                    ), f"Unsupported type for key '{key}': {type(value).__name__}"
             case "mma_kind":
-                assert isinstance(
-                    value, iree_gpu.MMAAttr
-                ), f"Unsupported type for key '{key}': {type(value).__name__}"
+                if not isinstance(value, iree_gpu.MMAAttr):
+                    assert (
+                        False
+                    ), f"Unsupported type for key '{key}': {type(value).__name__}"
             case _:
-                raise KeyError(f"Unhandled key in lowering configuration: {key}")
-        # Single assignment after the match.
+                assert False, f"Unhandled key in lowering configuration: {key}"
+
         lowering_config_dict[key] = promoted_value
     lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
     return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs)
diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py
index 503ece345..ad63ba815 100644
--- a/tuner/tuner/dispatch_parser.py
+++ b/tuner/tuner/dispatch_parser.py
@@ -21,17 +21,17 @@ def parse_tensor_type(tensor_type: str) -> ShapedType:
 
 
 def get_mmt_workgroup_sizes(configuration: Configuration):
-    return get_tilesize_workgroup(configuration)
+    return get_workgroup_tile_sizes(configuration)
 
 
 def get_mmt_reduction_sizes(configuration: Configuration):
-    return get_tilesize_reduction(configuration)
+    return get_reduction_tile_sizes(configuration)
 
 
 def get_contract_workgroup_sizes(
     configuration: Configuration, tile_dims: str
 ) -> list[int]:
-    m, n, _k = get_tilesize_workgroup(configuration)
+    m, n, _k = get_workgroup_tile_sizes(configuration)
 
     workgroup_size = [1] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
@@ -48,7 +48,7 @@ def get_contract_workgroup_sizes(
 def get_contract_reduction_sizes(
     configuration: Configuration, tile_dims: str
 ) -> list[int]:
-    _m, _n, k = get_tilesize_reduction(configuration)
+    _m, _n, k = get_reduction_tile_sizes(configuration)
     reduction_size = [0] * len(tile_dims)
     for idx, dim in enumerate(tile_dims):
         if dim == "k":
@@ -58,11 +58,11 @@ def get_contract_reduction_sizes(
 
 
 def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]:
-    return [1] + get_tilesize_workgroup(configuration)
+    return [1] + get_workgroup_tile_sizes(configuration)
 
 
 def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]:
-    return [0] + get_tilesize_reduction(configuration)
+    return [0] + get_reduction_tile_sizes(configuration)
 
 
 class MlirRegex(Enum):
@@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]:
 
         oh = 1
 
-        ow, oc, _ic = get_tilesize_workgroup(configuration)
+        ow, oc, _ic = get_workgroup_tile_sizes(configuration)
 
         return [batch, oh, ow, oc, fh, fw, 0]
 
     def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]:
-        _ow, _oc, ic = get_tilesize_reduction(configuration)
+        _ow, _oc, ic = get_reduction_tile_sizes(configuration)
         return [0, 0, 0, 0, 0, 0, ic]
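A minimal standalone check (not part of the patch) of the repl2/repl3 change in apply_configuration: formatting a plain Python list of ints directly produces the same bracketed text the old ", ".join expression built. The sample sizes below are made up for illustration.

# Hypothetical tile sizes; any list of ints behaves the same way.
workgroup_sizes = [64, 128, 0]
reduction_sizes = [0, 0, 32]

old_repl2 = f'workgroup = [{", ".join(map(str, workgroup_sizes))}]'
new_repl2 = f"workgroup = {workgroup_sizes}"
old_repl3 = f'reduction = [{", ".join(map(str, reduction_sizes))}]'
new_repl3 = f"reduction = {reduction_sizes}"

# Both spellings yield identical replacement strings for the lowering config.
assert old_repl2 == new_repl2 == "workgroup = [64, 128, 0]"
assert old_repl3 == new_repl3 == "reduction = [0, 0, 32]"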