diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index 6f90891e8..1c6ef5c8d 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -45,9 +45,9 @@ def apply_configuration( workgroup_sizes: list[int], reduction_sizes: list[int], ) -> str: - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) tune_logger.info(f"Applying: {configuration}") expr0 = re.compile( r", subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>" @@ -125,9 +125,9 @@ class MmtTuner(DispatchTuner, MmtParser): def get_transform_function_mmt( self, problem_size: ProblemSize, functionName: str, configuration: Configuration ) -> str: - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -199,9 +199,9 @@ def get_transform_function_conv( reduction_sizes = ", ".join( map(str, self.get_conv_reduction_sizes(configuration)) ) - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -269,9 +269,9 @@ def get_transform_function_broadcast_rhs_mmt( reduction_sizes = ", ".join( map(str, get_batch_mmt_reduction_sizes(configuration)) ) - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -359,9 +359,9 @@ def get_transform_function_batch_mmt( functionName: str, configuration: Configuration, ) -> str: - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) @@ -428,9 +428,9 @@ def get_transform_function_batch_matmul( input1 = f"tensor<{problem_size.rhs_type}>" output = f"tensor<{problem_size.res_type}>" - intrinsic = configuration.intrinsic() - subgroup_m_count = configuration.subgroup_m_count() - subgroup_n_count = configuration.subgroup_n_count() + intrinsic = get_intrinsic(configuration) + subgroup_m_count = get_subgroup_m_count(configuration) + subgroup_n_count = get_subgroup_n_count(configuration) wg_x, wg_y, wg_z = configuration.workgroup_size extra_config = get_pipeline_config(configuration) diff --git a/tuner/tuner/common.py b/tuner/tuner/common.py index 3253dd077..a839ad4c4 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -118,34 +118,39 @@ class Configuration: gpu_pipeline_options: iree_gpu.PipelineOptionsAttr waves_per_eu: int - def intrinsic(self) -> Optional[iree_gpu.MMAAttr]: - if "mma_kind" in self.lowering_config.attributes: - return self.lowering_config.attributes["mma_kind"] - return None - - def tilesize_workgroup(self) -> list[int]: - if "workgroup" in self.lowering_config.attributes: - workgroup_attrs = self.lowering_config.attributes["workgroup"] - return [attr.value for attr in workgroup_attrs] - return [] - - def tilesize_reduction(self) -> list[int]: - if "reduction" in self.lowering_config.attributes: - reduction_attrs = self.lowering_config.attributes["reduction"] - return [attr.value for attr in reduction_attrs] - return [] - - def subgroup_m_count(self) -> Optional[int]: - if "subgroup_m_count" in self.lowering_config.attributes: - attr = self.lowering_config.attributes["subgroup_m_count"] - return attr.value - return None - - def subgroup_n_count(self) -> Optional[int]: - if "subgroup_n_count" in self.lowering_config.attributes: - attr = self.lowering_config.attributes["subgroup_n_count"] - return attr.value - return None + +def get_intrinsic(config: Configuration) -> Optional[iree_gpu.MMAAttr]: + if "mma_kind" in config.lowering_config.attributes: + return config.lowering_config.attributes["mma_kind"] + return None + + +def get_tilesize_workgroup(config: Configuration) -> list[int]: + if "workgroup" in config.lowering_config.attributes: + workgroup_attrs = config.lowering_config.attributes["workgroup"] + return [attr.value for attr in workgroup_attrs] + return [] + + +def get_tilesize_reduction(config: Configuration) -> list[int]: + if "reduction" in config.lowering_config.attributes: + reduction_attrs = config.lowering_config.attributes["reduction"] + return [attr.value for attr in reduction_attrs] + return [] + + +def get_subgroup_m_count(config: Configuration) -> Optional[int]: + if "subgroup_m_count" in config.lowering_config.attributes: + attr = config.lowering_config.attributes["subgroup_m_count"] + return attr.value + return None + + +def get_subgroup_n_count(config: Configuration) -> Optional[int]: + if "subgroup_n_count" in config.lowering_config.attributes: + attr = config.lowering_config.attributes["subgroup_n_count"] + return attr.value + return None def get_lowering_config( @@ -154,36 +159,31 @@ def get_lowering_config( ) -> iree_gpu.LoweringConfigAttr: lowering_config_dict: dict[str, Any] = {} for key, value in kwargs.items(): + # A local variable to hold the transformed value. + promoted_value = value match key: case "workgroup" | "reduction": + assert isinstance( + value, (list, ir.ArrayAttr) + ), f"Unsupported type for key '{key}': {type(value).__name__}" if isinstance(value, list): - lowering_config_dict[key] = ir.ArrayAttr.get( + promoted_value = ir.ArrayAttr.get( [tuner_ctx.type.getI64(x) for x in value] ) - elif isinstance(value, ir.ArrayAttr): - lowering_config_dict[key] = value - else: - raise TypeError( - f"Unsupported type for key '{key}': {type(value).__name__}" - ) case "subgroup_m_count" | "subgroup_n_count": + assert isinstance( + value, (int, tuner_ctx.type.i64) + ), f"Unsupported type for key '{key}': {type(value).__name__}" if isinstance(value, int): - lowering_config_dict[key] = tuner_ctx.type.getI64(value) - elif isinstance(value, tuner_ctx.type.i64): - lowering_config_dict[key] = value - else: - raise TypeError( - f"Unsupported type for key '{key}': {type(value).__name__}" - ) + promoted_value = tuner_ctx.type.getI64(value) case "mma_kind": - if isinstance(value, iree_gpu.MMAAttr): - lowering_config_dict[key] = value - else: - raise TypeError( - f"Unsupported type for key '{key}': {type(value).__name__}" - ) + assert isinstance( + value, iree_gpu.MMAAttr + ), f"Unsupported type for key '{key}': {type(value).__name__}" case _: raise KeyError(f"Unhandled key in lowering configuration: {key}") + # Single assignment after the match. + lowering_config_dict[key] = promoted_value lowering_config_attrs = ir.DictAttr.get(lowering_config_dict) return iree_gpu.LoweringConfigAttr.get(lowering_config_attrs) diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 1dfb6ff7b..f13aed3d7 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -215,6 +215,6 @@ def test_get_lowering_config(tuner_ctx: common.TunerContext) -> None: waves_per_eu=2, ) - assert config.intrinsic() is None - assert config.subgroup_m_count() == 1 - assert config.subgroup_n_count() == 1 + assert common.get_intrinsic(config) is None + assert common.get_subgroup_m_count(config) == 1 + assert common.get_subgroup_n_count(config) == 1 diff --git a/tuner/tuner/dispatch_parser.py b/tuner/tuner/dispatch_parser.py index 0c5209ccd..503ece345 100644 --- a/tuner/tuner/dispatch_parser.py +++ b/tuner/tuner/dispatch_parser.py @@ -21,17 +21,17 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: def get_mmt_workgroup_sizes(configuration: Configuration): - return configuration.tilesize_workgroup() + return get_tilesize_workgroup(configuration) def get_mmt_reduction_sizes(configuration: Configuration): - return configuration.tilesize_reduction() + return get_tilesize_reduction(configuration) def get_contract_workgroup_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - m, n, _ = configuration.tilesize_workgroup() + m, n, _k = get_tilesize_workgroup(configuration) workgroup_size = [1] * len(tile_dims) for idx, dim in enumerate(tile_dims): @@ -48,7 +48,7 @@ def get_contract_workgroup_sizes( def get_contract_reduction_sizes( configuration: Configuration, tile_dims: str ) -> list[int]: - _, _, k = configuration.tilesize_reduction() + _m, _n, k = get_tilesize_reduction(configuration) reduction_size = [0] * len(tile_dims) for idx, dim in enumerate(tile_dims): if dim == "k": @@ -58,11 +58,11 @@ def get_contract_reduction_sizes( def get_batch_mmt_workgroup_sizes(configuration: Configuration) -> list[int]: - return [1] + configuration.tilesize_workgroup() + return [1] + get_tilesize_workgroup(configuration) def get_batch_mmt_reduction_sizes(configuration: Configuration) -> list[int]: - return [0] + configuration.tilesize_reduction() + return [0] + get_tilesize_reduction(configuration) class MlirRegex(Enum): @@ -171,12 +171,12 @@ def get_conv_workgroup_sizes(self, configuration: Configuration) -> list[int]: oh = 1 - ow, oc, _ic = configuration.tilesize_workgroup() + ow, oc, _ic = get_tilesize_workgroup(configuration) return [batch, oh, ow, oc, fh, fw, 0] def get_conv_reduction_sizes(self, configuration: Configuration) -> list[int]: - _ow, _oc, ic = configuration.tilesize_reduction() + _ow, _oc, ic = get_tilesize_reduction(configuration) return [0, 0, 0, 0, 0, 0, ic]