From 64e81899792e88029fb82d44e51c12372cca8ec4 Mon Sep 17 00:00:00 2001 From: Jakub Kuderski Date: Sat, 16 Nov 2024 14:32:40 -0500 Subject: [PATCH] [tuner] Use `ir.(Integer|Float)Type` for element types --- tuner/tuner/candidate_gen.py | 86 ++++++++++++----------- tuner/tuner/candidate_gen_test.py | 72 +++++++++++-------- tuner/tuner/common.py | 71 +++++++++---------- tuner/tuner/common_test.py | 70 ++++++++++++------- tuner/tuner/dispatch_constraints.py | 6 +- tuner/tuner/dispatch_constraints_test.py | 55 ++++++++------- tuner/tuner/dispatch_parser_test.py | 88 +++++++++++++----------- 7 files changed, 251 insertions(+), 197 deletions(-) diff --git a/tuner/tuner/candidate_gen.py b/tuner/tuner/candidate_gen.py index b50df12d5..ee331a2a6 100644 --- a/tuner/tuner/candidate_gen.py +++ b/tuner/tuner/candidate_gen.py @@ -517,48 +517,52 @@ def tune( with ir.Context() as ctx: tuner_context = TunerContext(ctx, tune_logger) - mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context) - # Save the input file as the first candidate. - with open(path.join(output, f"0.mlir"), "w") as f: - f.write(mlir_text) - - dispatch_tuner_registry = DispatchTunerRegistry() - dispatch_tuner_registry.register( - [ - MmtTuner(), - ConvTuner(), - ContractionTuner(lhs_dims, rhs_dims, tile_dims), - BatchMmtTuner(), - BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), - ] - ) + with parse_mlir(mlir_text, tuner_context) as mlir_module: + # Save the input file as the first candidate. + with open(path.join(output, f"0.mlir"), "w") as f: + f.write(mlir_text) + + dispatch_tuner_registry = DispatchTunerRegistry() + dispatch_tuner_registry.register( + [ + MmtTuner(), + ConvTuner(), + ContractionTuner(lhs_dims, rhs_dims, tile_dims), + BatchMmtTuner(), + BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims), + ] + ) - walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry) - - dispatch_tuner = walk_result.dispatch_tuner - assert dispatch_tuner, "No suitable dispatch tuner found" - problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) - tune_logger.debug(str(problem_size)) - configs = [] - for i, config in enumerate( - generate_solutions(tuner_context, problem_size, num_subgroups) - ): - if i >= limit: - break - tune_logger.info(f"Solution #{i+1}: {config}") - configs.append(config) - tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config) - - with open(path.join(output, f"{i+1}.mlir"), "w") as f: - f.write(tf_mlir.modified) - with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: - f.write(tf_mlir.embeddable) - - with open(path.join(output, "configs.pkl"), "wb") as file: - pickle.dump(configs, file) - - tune_logger.info(f"Generated {len(configs)} candidates") - tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") + walk_result: OpWalkResult = walk_mlir_op( + mlir_module, dispatch_tuner_registry + ) + + dispatch_tuner = walk_result.dispatch_tuner + assert dispatch_tuner, "No suitable dispatch tuner found" + problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template) + tune_logger.debug(str(problem_size)) + configs = [] + for i, config in enumerate( + generate_solutions(tune_logger, problem_size, num_subgroups) + ): + if i >= limit: + break + tune_logger.info(f"Solution #{i+1}: {config}") + configs.append(config) + tf_mlir = dispatch_tuner.apply_params( + problem_size, mlir_template, config + ) + + with open(path.join(output, f"{i+1}.mlir"), "w") as f: + f.write(tf_mlir.modified) + with open(path.join(output, f"{i+1}_config.mlir"), "w") as f: + 
f.write(tf_mlir.embeddable) + + with open(path.join(output, "configs.pkl"), "wb") as file: + pickle.dump(configs, file) + + tune_logger.info(f"Generated {len(configs)} candidates") + tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl") def main(): diff --git a/tuner/tuner/candidate_gen_test.py b/tuner/tuner/candidate_gen_test.py index 47e351fc7..36fb87cbb 100644 --- a/tuner/tuner/candidate_gen_test.py +++ b/tuner/tuner/candidate_gen_test.py @@ -10,17 +10,31 @@ import pytest +from typing import Generator + +from iree.compiler import ir # type: ignore + from . import candidate_gen from . import common +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + def remove_comments(mlir: str) -> str: return "\n".join( filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines()) ) -def test_apply_params_mmt() -> None: +def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", " None: problem_size = common.ProblemSize( common.MatmulSize(M, N, K), - common.ShapedType([M, K], common.ElementType.f16), - common.ShapedType([N, K], common.ElementType.f16), - common.ShapedType([M, N], common.ElementType.f32), + common.ShapedType([M, K], tuner_ctx.type.f16), + common.ShapedType([N, K], tuner_ctx.type.f16), + common.ShapedType([M, N], tuner_ctx.type.f32), common.DispatchKind.mmt, ) tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config) @@ -73,7 +87,7 @@ def test_apply_params_mmt() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified -def test_apply_params_conv() -> None: +def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 16, subgroup_n_count = 16>", " None: problem_size = common.ProblemSize( common.MatmulSize(oh * ow, oc, fh * fw * ic), - common.ShapedType([n, oh + 2, ow + 2, oc], common.ElementType.f16), - common.ShapedType([fh, fw, ic, oc], common.ElementType.f16), - common.ShapedType([n, oh, ow, oc], common.ElementType.f32), + common.ShapedType([n, oh + 2, ow + 2, oc], tuner_ctx.type.f16), + common.ShapedType([fh, fw, ic, oc], tuner_ctx.type.f16), + common.ShapedType([n, oh, ow, oc], tuner_ctx.type.f32), common.DispatchKind.conv, ) tf_mlir = candidate_gen.ConvTuner().apply_params( @@ -130,7 +144,7 @@ def test_apply_params_conv() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified -def test_apply_params_contract() -> None: +def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 2, subgroup_n_count = 2>}>", " None: tile_dims = "*mnk" problem_size = common.ProblemSize( common.MatmulSize(2048, 3840, 1280), - common.ShapedType([2, 1024, 1280], common.ElementType.f16), - common.ShapedType([3, 20, 64, 1280], common.ElementType.f16), - common.ShapedType([3, 2, 20, 1024, 64], common.ElementType.f32), + common.ShapedType([2, 1024, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 20, 64, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 2, 20, 1024, 64], tuner_ctx.type.f32), common.DispatchKind.contraction, ) @@ -177,7 +191,7 @@ def test_apply_params_contract() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir -def test_apply_params_batch_matmul() -> None: +def 
test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: tile_dims = "bmnk" problem_size = common.ProblemSize( common.MatmulSize(968, 320, 640, 64), - common.ShapedType([64, 968, 640], common.ElementType.f16), - common.ShapedType([64, 640, 320], common.ElementType.f16), - common.ShapedType([64, 968, 320], common.ElementType.f32), + common.ShapedType([64, 968, 640], tuner_ctx.type.f16), + common.ShapedType([64, 640, 320], tuner_ctx.type.f16), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), common.DispatchKind.batch_matmul, ) @@ -228,7 +242,7 @@ def test_apply_params_batch_matmul() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified -def test_apply_params_batch_mmt_float() -> None: +def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: problem_size = common.ProblemSize( common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], common.ElementType.f16), - common.ShapedType([2, 640, 640], common.ElementType.f16), - common.ShapedType([2, 4096, 640], common.ElementType.f32), + common.ShapedType([2, 4096, 640], tuner_ctx.type.f16), + common.ShapedType([2, 640, 640], tuner_ctx.type.f16), + common.ShapedType([2, 4096, 640], tuner_ctx.type.f32), common.DispatchKind.batch_mmt, ) @@ -276,7 +290,7 @@ def test_apply_params_batch_mmt_float() -> None: assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified -def test_apply_params_batch_mmt_int() -> None: +def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: problem_size = common.ProblemSize( common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], common.ElementType.i8), - common.ShapedType([2, 640, 640], common.ElementType.i8), - common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), + common.ShapedType([2, 640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), common.DispatchKind.batch_mmt, ) @@ -347,7 +361,7 @@ def test_apply_params_batch_mmt_int() -> None: assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable -def test_apply_params_broadcast_rhs_mmt() -> None: +def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_template = [ ", subgroup_m_count = 4, subgroup_n_count = 1>}>", " None: problem_size = common.ProblemSize( common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], common.ElementType.i8), - common.ShapedType([640, 640], common.ElementType.i8), - common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i8), + common.ShapedType([640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), common.DispatchKind.broadcast_rhs_mmt, ) @@ -422,7 +436,7 @@ def test_apply_params_broadcast_rhs_mmt() -> None: assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable -def test_detect_broadcast_rhs_mmt() -> None: +def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None: mlir_lines = [ r"%18 = tensor.empty() : tensor<2x1024x10240xi32>", r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>", diff --git a/tuner/tuner/common.py 
b/tuner/tuner/common.py index 7b295cdb0..aff49ef43 100644 --- a/tuner/tuner/common.py +++ b/tuner/tuner/common.py @@ -13,10 +13,27 @@ from iree.compiler import ir # type: ignore +class CommonTypes: + def __init__(self, ctx: ir.Context): + assert ctx + self.i1 = ir.IntegerType.get_signless(1, ctx) + self.i8 = ir.IntegerType.get_signless(8, ctx) + self.i16 = ir.IntegerType.get_signless(16, ctx) + self.i32 = ir.IntegerType.get_signless(32, ctx) + + self.f8E4M3FNUZ = ir.Float8E4M3FNUZType.get(ctx) + self.f8E5M2FNUZ = ir.Float8E5M2FNUZType.get(ctx) + self.f16 = ir.F16Type.get(ctx) + self.f32 = ir.F32Type.get(ctx) + + self.bf16 = ir.BF16Type.get(ctx) + + class TunerContext: def __init__(self, mlir_ctx: ir.Context, logger: logging.Logger): - self.mlir_ctx = mlir_ctx - self.logger = logger + self.mlir_ctx: ir.Context = mlir_ctx + self.logger: logging.Logger = logger + self.type: CommonTypes = CommonTypes(mlir_ctx) class DispatchKind(Enum): @@ -28,40 +45,17 @@ class DispatchKind(Enum): broadcast_rhs_mmt = 6 -class ElementType(Enum): - i8 = 1 - i32 = 2 - f8 = 3 - f16 = 4 - f32 = 5 - - @property - def bitwidth(self) -> int: - match self: - case ElementType.i8 | ElementType.f8: - return 8 - case ElementType.f16: - return 16 - case ElementType.i32 | ElementType.f32: - return 32 - case _: - assert False, "unhandled case" - - def __str__(self) -> str: - return self.name - - @dataclass class ShapedType: shape: list[int] - element_type: ElementType + element_type: ir.IntegerType | ir.FloatType def rank(self) -> int: return len(self.shape) @property def bitwidth(self) -> int: - return self.element_type.bitwidth + return self.element_type.width def __str__(self) -> str: dim_to_str = lambda dim: str(dim) if dim != -1 else "?" @@ -91,11 +85,11 @@ def MNK(self) -> tuple[int, int, int]: @dataclass class MfmaIntrinsic: - output_type: ElementType + output_type: ir.IntegerType | ir.FloatType m: int n: int k: int - input_type: ElementType + input_type: ir.IntegerType | ir.FloatType def __str__(self) -> str: input = str(self.input_type).upper() @@ -104,19 +98,27 @@ def __str__(self) -> str: @staticmethod def mfma_f32_16x16x16_f16(): - return MfmaIntrinsic(ElementType.f32, 16, 16, 16, ElementType.f16) + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + return MfmaIntrinsic(f32, 16, 16, 16, f16) @staticmethod def mfma_f32_32x32x8_f16(): - return MfmaIntrinsic(ElementType.f32, 32, 32, 8, ElementType.f16) + f16 = ir.F16Type.get() + f32 = ir.F32Type.get() + return MfmaIntrinsic(f32, 32, 32, 8, f16) @staticmethod def mfma_i32_16x16x32_i8(): - return MfmaIntrinsic(ElementType.i32, 16, 16, 32, ElementType.i8) + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + return MfmaIntrinsic(i32, 16, 16, 32, i8) @staticmethod def mfma_i32_32x32x16_i8(): - return MfmaIntrinsic(ElementType.i32, 32, 32, 16, ElementType.i8) + i32 = ir.IntegerType.get_signless(32) + i8 = ir.IntegerType.get_signless(8) + return MfmaIntrinsic(i32, 32, 32, 16, i8) @staticmethod def all(): @@ -251,8 +253,7 @@ def parse_tensor_type(tensor_type: str) -> ShapedType: dims_and_elem = shape_str.split("x") dims = [int(x) for x in dims_and_elem[:-1]] elem = dims_and_elem[-1] - str_to_elem_ty = {x.name: x for x in ElementType} - return ShapedType(dims, str_to_elem_ty[elem]) + return ShapedType(dims, ir.Type.parse(elem)) @dataclass diff --git a/tuner/tuner/common_test.py b/tuner/tuner/common_test.py index 858d593c9..86a47c1c5 100644 --- a/tuner/tuner/common_test.py +++ b/tuner/tuner/common_test.py @@ -11,27 +11,47 @@ import pytest from . 
import common +from typing import Generator -def test_get_shaped_type_element_bitwidth() -> None: - assert common.ShapedType([1024, 2048], common.ElementType.i8).bitwidth == 8 - assert common.ShapedType([2048], common.ElementType.i32).bitwidth == 32 - assert common.ShapedType([2048, 512, 384], common.ElementType.f8).bitwidth == 8 - assert common.ShapedType([1, 1], common.ElementType.f16).bitwidth == 16 +from iree.compiler import ir # type: ignore -def test_get_shaped_type_to_str() -> None: - assert str(common.ShapedType([1024, 2048], common.ElementType.i8)) == "1024x2048xi8" - assert str(common.ShapedType([1024], common.ElementType.f32)) == "1024xf32" - assert str(common.ShapedType([1, 2, 3], common.ElementType.f16)) == "1x2x3xf16" - assert str(common.ShapedType([-1, 2, 3], common.ElementType.f16)) == "?x2x3xf16" +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) -def test_parse_tensor_type() -> None: + +@pytest.fixture +def mlir_ctx() -> Generator[ir.Context, None, None]: + with ir.Context() as ctx: + yield ctx + + +def test_get_shaped_type_element_bitwidth(tuner_ctx: common.TunerContext) -> None: + assert common.ShapedType([1024, 2048], tuner_ctx.type.i8).bitwidth == 8 + assert common.ShapedType([2048], tuner_ctx.type.i32).bitwidth == 32 + assert common.ShapedType([2048, 512, 384], tuner_ctx.type.f8E4M3FNUZ).bitwidth == 8 + assert common.ShapedType([1, 1], tuner_ctx.type.f16).bitwidth == 16 + + +def test_get_shaped_type_to_str(tuner_ctx: common.TunerContext) -> None: + assert str(common.ShapedType([1024, 2048], tuner_ctx.type.i8)) == "1024x2048xi8" + assert str(common.ShapedType([1024], tuner_ctx.type.f32)) == "1024xf32" + assert str(common.ShapedType([1, 2, 3], tuner_ctx.type.f16)) == "1x2x3xf16" + assert str(common.ShapedType([-1, 2, 3], tuner_ctx.type.f16)) == "?x2x3xf16" + + +def test_parse_tensor_type(tuner_ctx: common.TunerContext) -> None: assert common.parse_tensor_type("tensor<1x2x3xf32>") == common.ShapedType( - [1, 2, 3], common.ElementType.f32 + [1, 2, 3], tuner_ctx.type.f32 ) assert common.parse_tensor_type("tensor<123xi8>") == common.ShapedType( - [123], common.ElementType.i8 + [123], tuner_ctx.type.i8 ) @@ -59,7 +79,7 @@ def test_gpu_pipeline_options() -> None: ) -def test_get_pipeline_config() -> None: +def test_get_pipeline_config(mlir_ctx: ir.Context) -> None: config = common.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], @@ -85,18 +105,18 @@ def test_get_pipeline_config() -> None: ) -def test_mfma_intrinsic_to_str() -> None: +def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None: assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16" assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8" -def test_get_compatible_mfma_intrinsics() -> None: +def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( common.MatmulSize(2048, 1280, 1280), - common.ShapedType([2048, 1280], common.ElementType.f16), - common.ShapedType([1280, 1280], common.ElementType.f16), - common.ShapedType([2048, 1280], common.ElementType.f32), + common.ShapedType([2048, 1280], tuner_ctx.type.f16), + common.ShapedType([1280, 1280], tuner_ctx.type.f16), + common.ShapedType([2048, 1280], tuner_ctx.type.f32), common.DispatchKind.mmt, ) ) == [ 
@@ -107,9 +127,9 @@ def test_get_compatible_mfma_intrinsics() -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( common.MatmulSize(2048, 1280, 1280), - common.ShapedType([2048, 1280], common.ElementType.i8), - common.ShapedType([1280, 1280], common.ElementType.i8), - common.ShapedType([2048, 1280], common.ElementType.i32), + common.ShapedType([2048, 1280], tuner_ctx.type.i8), + common.ShapedType([1280, 1280], tuner_ctx.type.i8), + common.ShapedType([2048, 1280], tuner_ctx.type.i32), common.DispatchKind.mmt, ) ) == [ @@ -120,9 +140,9 @@ def test_get_compatible_mfma_intrinsics() -> None: assert common.get_compatible_mfma_intrinsics( common.ProblemSize( common.MatmulSize(968, 320, 640, 64), - common.ShapedType([64, 968, 640], common.ElementType.f32), - common.ShapedType([64, 640, 320], common.ElementType.f32), - common.ShapedType([64, 968, 320], common.ElementType.f32), + common.ShapedType([64, 968, 640], tuner_ctx.type.f32), + common.ShapedType([64, 640, 320], tuner_ctx.type.f32), + common.ShapedType([64, 968, 320], tuner_ctx.type.f32), common.DispatchKind.batch_matmul, ) ) == [ diff --git a/tuner/tuner/dispatch_constraints.py b/tuner/tuner/dispatch_constraints.py index ac46d8edd..edd7ccc38 100644 --- a/tuner/tuner/dispatch_constraints.py +++ b/tuner/tuner/dispatch_constraints.py @@ -130,10 +130,10 @@ def generate_constraints( def generate_solutions( - ctx: TunerContext, problem_size: ProblemSize, num_subgrups: int + logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int ) -> Iterator[Configuration]: M, N, K = problem_size.MNK - ctx.logger.info(f"{M},{N},{K}") + logger.info(f"{M},{N},{K}") m, n, k = z3.Int("m"), z3.Int("n"), z3.Int("k") subgroup_size = z3.Int("subgroup_size") intrinsic_mn = z3.Int("intrinsic_mn") @@ -170,7 +170,7 @@ def generate_solutions( waves_per_eu, ) solver.add(z3.simplify(z3.And(constraints))) - ctx.logger.debug(f"Initial constraints: {solver}") + logger.debug(f"Initial constraints: {solver}") i = 0 while solver.check() == z3.sat: model = solver.model() diff --git a/tuner/tuner/dispatch_constraints_test.py b/tuner/tuner/dispatch_constraints_test.py index 55f3a8c43..7e1a5c55d 100644 --- a/tuner/tuner/dispatch_constraints_test.py +++ b/tuner/tuner/dispatch_constraints_test.py @@ -11,32 +11,41 @@ import pytest import z3 # type: ignore -from logging import Logger -from unittest.mock import MagicMock +from typing import Generator + +from iree.compiler import ir # type: ignore from . import common from . 
import dispatch_constraints -def test_generate_solutions() -> None: +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + +def test_generate_solutions(tuner_ctx: common.TunerContext) -> None: matmul_size = common.MatmulSize(2048, 3840, 1280) - lhs_type = common.ShapedType([2048, 1280], common.ElementType.f16) - rhs_type = common.ShapedType([3840, 1280], common.ElementType.f16) - res_type = common.ShapedType([2048, 3840], common.ElementType.f32) + lhs_type = common.ShapedType([2048, 1280], tuner_ctx.type.f16) + rhs_type = common.ShapedType([3840, 1280], tuner_ctx.type.f16) + res_type = common.ShapedType([2048, 3840], tuner_ctx.type.f32) problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) - logger: Logger = MagicMock(spec=Logger) - ctx = common.TunerContext(None, logger) - configs = dispatch_constraints.generate_solutions(ctx, problem_size, 4) + configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4) assert configs is not None -def test_calculate_shared_memory_usage_in_bytes() -> None: +def test_calculate_shared_memory_usage_in_bytes(tuner_ctx: common.TunerContext) -> None: matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) @@ -47,7 +56,7 @@ def test_calculate_shared_memory_usage_in_bytes() -> None: == 147456 ) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.i8) + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i8) problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) @@ -58,7 +67,7 @@ def test_calculate_shared_memory_usage_in_bytes() -> None: == 81920 ) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.i32) + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.i32) problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) @@ -70,11 +79,11 @@ def test_calculate_shared_memory_usage_in_bytes() -> None: ) -def test_generate_constraints_valid_input() -> None: +def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> None: matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) @@ -115,12 +124,12 @@ def test_generate_constraints_valid_input() -> None: assert solver.check() == z3.sat -def test_generate_constraints_invalid_input() -> None: +def 
test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> None: # Define input parameters that should lead to unsatisfiable constraints matmul_size = common.MatmulSize(1024, 1024, 1024) - lhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - rhs_type = common.ShapedType([1024, 1024], common.ElementType.f16) - res_type = common.ShapedType([1024, 1024], common.ElementType.f32) + lhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + rhs_type = common.ShapedType([1024, 1024], tuner_ctx.type.f16) + res_type = common.ShapedType([1024, 1024], tuner_ctx.type.f32) problem_size = common.ProblemSize( matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt ) diff --git a/tuner/tuner/dispatch_parser_test.py b/tuner/tuner/dispatch_parser_test.py index bcdee240c..d473e5854 100644 --- a/tuner/tuner/dispatch_parser_test.py +++ b/tuner/tuner/dispatch_parser_test.py @@ -10,8 +10,7 @@ import pytest -from logging import Logger -from unittest.mock import MagicMock +from typing import Generator from iree.compiler import ir # type: ignore from iree.compiler.dialects import func # type: ignore @@ -20,7 +19,17 @@ from . import dispatch_parser -def test_get_mmt_tile_sizes() -> None: +@pytest.fixture +def tuner_ctx() -> Generator[common.TunerContext, None, None]: + from logging import Logger + from unittest.mock import MagicMock + + with ir.Context() as ctx: + logger: Logger = MagicMock(spec=Logger) + yield common.TunerContext(ctx, logger) + + +def test_get_mmt_tile_sizes(tuner_ctx: common.TunerContext) -> None: config = dispatch_parser.Configuration( subgroup_size=0, workgroup_size=[], @@ -34,7 +43,7 @@ def test_get_mmt_tile_sizes() -> None: assert dispatch_parser.get_mmt_tile_sizes(config) == [128, 320, 32] -def test_get_conv_tile_sizes() -> None: +def test_get_conv_tile_sizes(tuner_ctx: common.TunerContext) -> None: config = dispatch_parser.Configuration( subgroup_size=64, workgroup_size=[256, 1, 1], @@ -56,7 +65,7 @@ def test_get_conv_tile_sizes() -> None: ] -def test_get_contract_tile_sizes() -> None: +def test_get_contract_tile_sizes(tuner_ctx: common.TunerContext) -> None: config = dispatch_parser.Configuration( subgroup_size=32, workgroup_size=[16, 16, 1], @@ -77,7 +86,7 @@ def test_get_contract_tile_sizes() -> None: ] -def test_get_shapes_mmt() -> None: +def test_get_shapes_mmt(tuner_ctx: common.TunerContext) -> None: template = [ r"%18 = tensor.empty() : tensor<2048x1280xf32>", r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", @@ -86,14 +95,14 @@ def test_get_shapes_mmt() -> None: ] assert dispatch_parser.MmtParser().get_shapes(template) == common.ProblemSize( common.MatmulSize(2048, 1280, 1280), - common.ShapedType([2048, 1280], common.ElementType.f16), - common.ShapedType([1280, 1280], common.ElementType.f16), - common.ShapedType([2048, 1280], common.ElementType.f32), + common.ShapedType([2048, 1280], tuner_ctx.type.f16), + common.ShapedType([1280, 1280], tuner_ctx.type.f16), + common.ShapedType([2048, 1280], tuner_ctx.type.f32), dispatch_parser.DispatchKind.mmt, ) -def test_get_shapes_conv() -> None: +def test_get_shapes_conv(tuner_ctx: common.TunerContext) -> None: template = [ r"%7 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%4 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", r"%8 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : vector<2xi64>, lowering_config = #iree_codegen.lowering_config, strides = 
dense<1> : vector<2xi64>} ins(%5, %6 : tensor<1x3x34x1280xf16>, tensor<3x3x1280x256xf16>) outs(%7 : tensor<1x1x32x256xf32>) -> tensor<1x1x32x256xf32>", @@ -101,14 +110,14 @@ def test_get_shapes_conv() -> None: ] assert dispatch_parser.ConvParser().get_shapes(template) == common.ProblemSize( common.MatmulSize(32, 256, 11520), - common.ShapedType([1, 3, 34, 1280], common.ElementType.f16), - common.ShapedType([3, 3, 1280, 256], common.ElementType.f16), - common.ShapedType([1, 1, 32, 256], common.ElementType.f32), + common.ShapedType([1, 3, 34, 1280], tuner_ctx.type.f16), + common.ShapedType([3, 3, 1280, 256], tuner_ctx.type.f16), + common.ShapedType([1, 1, 32, 256], tuner_ctx.type.f32), dispatch_parser.DispatchKind.conv, ) -def test_get_shapes_contract() -> None: +def test_get_shapes_contract(tuner_ctx: common.TunerContext) -> None: template = [ r"%18 = tensor.empty() : tensor<2048x1280xf32>", r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%cst : f32) outs(%18 : tensor<2048x1280xf32>) -> tensor<2048x1280xf32>", @@ -119,14 +128,14 @@ def test_get_shapes_contract() -> None: template ) == common.ProblemSize( common.MatmulSize(2048, 1280, 1280), - common.ShapedType([2048, 1280], common.ElementType.f16), - common.ShapedType([1280, 1280], common.ElementType.f16), - common.ShapedType([2048, 1280], common.ElementType.f32), + common.ShapedType([2048, 1280], tuner_ctx.type.f16), + common.ShapedType([1280, 1280], tuner_ctx.type.f16), + common.ShapedType([2048, 1280], tuner_ctx.type.f32), dispatch_parser.DispatchKind.contraction, ) -def test_get_shapes_batch_matmul() -> None: +def test_get_shapes_batch_matmul(tuner_ctx: common.TunerContext) -> None: template = [ "%10 = linalg.fill ins(%cst : f32) outs(%7 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", "%11 = linalg.batch_matmul ins(%8, %9 : tensor<1x32x1024xf32>, tensor<1x1024x32xf32>) outs(%10 : tensor<1x32x32xf32>) -> tensor<1x32x32xf32>", @@ -136,14 +145,14 @@ def test_get_shapes_batch_matmul() -> None: template ) == common.ProblemSize( common.MatmulSize(32, 32, 1024, 1), - common.ShapedType([1, 32, 1024], common.ElementType.f32), - common.ShapedType([1, 1024, 32], common.ElementType.f32), - common.ShapedType([1, 32, 32], common.ElementType.f32), + common.ShapedType([1, 32, 1024], tuner_ctx.type.f32), + common.ShapedType([1, 1024, 32], tuner_ctx.type.f32), + common.ShapedType([1, 32, 32], tuner_ctx.type.f32), dispatch_parser.DispatchKind.batch_matmul, ) -def test_get_shapes_batch_mmt() -> None: +def test_get_shapes_batch_mmt(tuner_ctx: common.TunerContext) -> None: template = [ r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config} ins(%c0_i32 : i32) outs(%18 : tensor<2x4096x640xi32>) -> tensor<2x4096x640xi32>", r'%20 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d2, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel", "reduction"]} ins(%11, %12 : tensor<2x4096x640xi8>, tensor<2x640x640xi8>) outs(%19 : tensor<2x4096x640xi32>) attrs = {lowering_config = #iree_codegen.lowering_config} {', @@ -151,26 +160,23 @@ def test_get_shapes_batch_mmt() -> None: ] assert dispatch_parser.BatchMmtParser().get_shapes(template) == common.ProblemSize( common.MatmulSize(4096, 640, 640, 2), - common.ShapedType([2, 4096, 640], common.ElementType.i8), - common.ShapedType([2, 640, 640], common.ElementType.i8), - common.ShapedType([2, 4096, 640], common.ElementType.i32), + common.ShapedType([2, 4096, 640], 
tuner_ctx.type.i8), + common.ShapedType([2, 640, 640], tuner_ctx.type.i8), + common.ShapedType([2, 4096, 640], tuner_ctx.type.i32), dispatch_parser.DispatchKind.batch_mmt, ) -def test_parse_mlir() -> None: - with ir.Context() as ctx: - mlir_str = r""" - builtin.module { - func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { - %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> - return %0 : tensor<4xf32> - } - } - """ - logger: Logger = MagicMock(spec=Logger) - tuner_context = common.TunerContext(ctx, logger) - mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_context) - assert mlir_module is not None - assert isinstance(mlir_module, ir.Module) - assert isinstance(mlir_module.body.operations[0], func.FuncOp) +def test_parse_mlir(tuner_ctx: common.TunerContext) -> None: + mlir_str = r""" + builtin.module { + func.func @simple_mul(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { + %0 = arith.mulf %arg0, %arg1 : tensor<4xf32> + return %0 : tensor<4xf32> + } + } +""" + mlir_module = dispatch_parser.parse_mlir(mlir_str, tuner_ctx) + assert mlir_module is not None + assert isinstance(mlir_module, ir.Module) + assert isinstance(mlir_module.body.operations[0], func.FuncOp)
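
Usage sketch (not part of the patch): the snippet below illustrates how the new context-bound element types introduced here are meant to be used, mirroring the `tuner_ctx` test fixtures added above. It assumes the patch is applied and the module is importable as `tuner.common`; the `demo` function name and the MagicMock logger are illustrative only.

    # Minimal sketch of the new CommonTypes-based element types.
    from logging import Logger
    from unittest.mock import MagicMock

    from iree.compiler import ir  # type: ignore

    from tuner import common  # assumed import path for tuner/tuner/common.py


    def demo() -> None:
        # ir.IntegerType / ir.FloatType instances are uniqued per ir.Context,
        # so they must be created (and compared) while a context is active.
        with ir.Context() as ctx:
            logger: Logger = MagicMock(spec=Logger)
            tuner_ctx = common.TunerContext(ctx, logger)

            # CommonTypes exposes the commonly used element types
            # (i8, i16, i32, f16, f32, bf16, f8E4M3FNUZ, ...).
            lhs = common.ShapedType([2048, 1280], tuner_ctx.type.f16)
            acc = common.ShapedType([2048, 1280], tuner_ctx.type.f32)

            # Bitwidth now comes straight from the MLIR type's `width` attribute.
            assert lhs.bitwidth == 16 and acc.bitwidth == 32

            # parse_tensor_type delegates element-type parsing to ir.Type.parse,
            # so the parsed f16 compares equal to tuner_ctx.type.f16.
            assert common.parse_tensor_type("tensor<2048x1280xf16>") == lhs

            # The MfmaIntrinsic helpers likewise build their types from the
            # ambient context (see test_mfma_intrinsic_to_str above).
            assert (
                str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16())
                == "MFMA_F32_16x16x16_F16"
            )


    if __name__ == "__main__":
        demo()

Because the types now live on the ir.Context rather than in a module-level ElementType enum, every test that builds a ShapedType yields its TunerContext from inside a live ir.Context, which is why the tests above replace the old free functions with the tuner_ctx / mlir_ctx pytest fixtures.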