Commit 64e8189

[tuner] Use ir.(Integer|Float)Type for element types

kuhar committed Nov 16, 2024
1 parent 5ccfc87 commit 64e8189
Showing 7 changed files with 251 additions and 197 deletions.
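
For orientation: this commit replaces the tuner's homegrown common.ElementType enum with real MLIR element types (ir.IntegerType / ir.FloatType) reached through the existing TunerContext, which is why the tests below gain a tuner_ctx fixture and write tuner_ctx.type.f16 instead of common.ElementType.f16. A minimal sketch of the shape this implies follows; the _TypeRegistry helper and all attribute names inside it are assumptions inferred from the diff, not the repository's actual code.

from logging import Logger

from iree.compiler import ir  # type: ignore


class _TypeRegistry:
    """Hypothetical helper: hands out MLIR element types bound to one ir.Context."""

    def __init__(self, ctx: ir.Context):
        self._ctx = ctx

    @property
    def f16(self) -> ir.F16Type:
        with self._ctx:
            return ir.F16Type.get()

    @property
    def f32(self) -> ir.F32Type:
        with self._ctx:
            return ir.F32Type.get()

    @property
    def i8(self) -> ir.IntegerType:
        with self._ctx:
            return ir.IntegerType.get_signless(8)

    @property
    def i32(self) -> ir.IntegerType:
        with self._ctx:
            return ir.IntegerType.get_signless(32)


class TunerContext:
    """Shape matching the constructor calls in the diff: TunerContext(ctx, logger)."""

    def __init__(self, mlir_ctx: ir.Context, logger: Logger):
        self.mlir_ctx = mlir_ctx
        self.logger = logger
        self.type = _TypeRegistry(mlir_ctx)
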
86 changes: 45 additions & 41 deletions tuner/tuner/candidate_gen.py
@@ -517,48 +517,52 @@ def tune(

with ir.Context() as ctx:
tuner_context = TunerContext(ctx, tune_logger)
-        mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context)
-        # Save the input file as the first candidate.
-        with open(path.join(output, f"0.mlir"), "w") as f:
-            f.write(mlir_text)
-
-        dispatch_tuner_registry = DispatchTunerRegistry()
-        dispatch_tuner_registry.register(
-            [
-                MmtTuner(),
-                ConvTuner(),
-                ContractionTuner(lhs_dims, rhs_dims, tile_dims),
-                BatchMmtTuner(),
-                BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
-            ]
-        )
+        with parse_mlir(mlir_text, tuner_context) as mlir_module:
+            # Save the input file as the first candidate.
+            with open(path.join(output, f"0.mlir"), "w") as f:
+                f.write(mlir_text)
+
+            dispatch_tuner_registry = DispatchTunerRegistry()
+            dispatch_tuner_registry.register(
+                [
+                    MmtTuner(),
+                    ConvTuner(),
+                    ContractionTuner(lhs_dims, rhs_dims, tile_dims),
+                    BatchMmtTuner(),
+                    BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
+                ]
+            )

-        walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
-
-        dispatch_tuner = walk_result.dispatch_tuner
-        assert dispatch_tuner, "No suitable dispatch tuner found"
-        problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
-        tune_logger.debug(str(problem_size))
-        configs = []
-        for i, config in enumerate(
-            generate_solutions(tuner_context, problem_size, num_subgroups)
-        ):
-            if i >= limit:
-                break
-            tune_logger.info(f"Solution #{i+1}: {config}")
-            configs.append(config)
-            tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
-
-            with open(path.join(output, f"{i+1}.mlir"), "w") as f:
-                f.write(tf_mlir.modified)
-            with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
-                f.write(tf_mlir.embeddable)
-
-        with open(path.join(output, "configs.pkl"), "wb") as file:
-            pickle.dump(configs, file)
-
-        tune_logger.info(f"Generated {len(configs)} candidates")
-        tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
+            walk_result: OpWalkResult = walk_mlir_op(
+                mlir_module, dispatch_tuner_registry
+            )
+
+            dispatch_tuner = walk_result.dispatch_tuner
+            assert dispatch_tuner, "No suitable dispatch tuner found"
+            problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
+            tune_logger.debug(str(problem_size))
+            configs = []
+            for i, config in enumerate(
+                generate_solutions(tune_logger, problem_size, num_subgroups)
+            ):
+                if i >= limit:
+                    break
+                tune_logger.info(f"Solution #{i+1}: {config}")
+                configs.append(config)
+                tf_mlir = dispatch_tuner.apply_params(
+                    problem_size, mlir_template, config
+                )
+
+                with open(path.join(output, f"{i+1}.mlir"), "w") as f:
+                    f.write(tf_mlir.modified)
+                with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
+                    f.write(tf_mlir.embeddable)
+
+            with open(path.join(output, "configs.pkl"), "wb") as file:
+                pickle.dump(configs, file)
+
+            tune_logger.info(f"Generated {len(configs)} candidates")
+            tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")


def main():
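
The candidate_gen.py change above turns parse_mlir into a context manager (with parse_mlir(...) as mlir_module:) instead of a plain function returning ir.Module. Below is a minimal sketch of such a signature, assuming a @contextmanager implementation and a TunerContext exposing mlir_ctx; neither detail is confirmed by the diff.

from contextlib import contextmanager
from typing import Generator

from iree.compiler import ir  # type: ignore


@contextmanager
def parse_mlir(mlir_text: str, tuner_ctx) -> Generator[ir.Module, None, None]:
    # Parse inside the tuner's MLIR context so the module, and anything derived
    # from it, stays valid exactly as long as the caller's `with` block.
    with tuner_ctx.mlir_ctx:
        yield ir.Module.parse(mlir_text)
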
72 changes: 43 additions & 29 deletions tuner/tuner/candidate_gen_test.py
@@ -10,17 +10,31 @@

import pytest

+from typing import Generator
+
+from iree.compiler import ir  # type: ignore
+
from . import candidate_gen
from . import common


+@pytest.fixture
+def tuner_ctx() -> Generator[common.TunerContext, None, None]:
+    from logging import Logger
+    from unittest.mock import MagicMock
+
+    with ir.Context() as ctx:
+        logger: Logger = MagicMock(spec=Logger)
+        yield common.TunerContext(ctx, logger)
+
+
def remove_comments(mlir: str) -> str:
return "\n".join(
filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines())
)


-def test_apply_params_mmt() -> None:
+def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
mlir_template = [
"<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 16, subgroup_n_count = 16>",
"<LLVMGPUVectorDistribute workgroup_size = [16, 16] subgroup_size = 16,",
@@ -44,9 +58,9 @@ def test_apply_params_mmt() -> None:

problem_size = common.ProblemSize(
common.MatmulSize(M, N, K),
-        common.ShapedType([M, K], common.ElementType.f16),
-        common.ShapedType([N, K], common.ElementType.f16),
-        common.ShapedType([M, N], common.ElementType.f32),
+        common.ShapedType([M, K], tuner_ctx.type.f16),
+        common.ShapedType([N, K], tuner_ctx.type.f16),
+        common.ShapedType([M, N], tuner_ctx.type.f32),
common.DispatchKind.mmt,
)
tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config)
@@ -73,7 +87,7 @@ def test_apply_params_mmt() -> None:
assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified


-def test_apply_params_conv() -> None:
+def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
mlir_template = [
"<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 16, subgroup_n_count = 16>",
"<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64,",
@@ -98,9 +112,9 @@ def test_apply_params_conv() -> None:

problem_size = common.ProblemSize(
common.MatmulSize(oh * ow, oc, fh * fw * ic),
-        common.ShapedType([n, oh + 2, ow + 2, oc], common.ElementType.f16),
-        common.ShapedType([fh, fw, ic, oc], common.ElementType.f16),
-        common.ShapedType([n, oh, ow, oc], common.ElementType.f32),
+        common.ShapedType([n, oh + 2, ow + 2, oc], tuner_ctx.type.f16),
+        common.ShapedType([fh, fw, ic, oc], tuner_ctx.type.f16),
+        common.ShapedType([n, oh, ow, oc], tuner_ctx.type.f32),
common.DispatchKind.conv,
)
tf_mlir = candidate_gen.ConvTuner().apply_params(
@@ -130,7 +144,7 @@ def test_apply_params_conv() -> None:
assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified


-def test_apply_params_contract() -> None:
+def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
mlir_template = [
"<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>",
"<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64,",
@@ -141,9 +155,9 @@ def test_apply_params_contract() -> None:
tile_dims = "*mnk"
problem_size = common.ProblemSize(
common.MatmulSize(2048, 3840, 1280),
-        common.ShapedType([2, 1024, 1280], common.ElementType.f16),
-        common.ShapedType([3, 20, 64, 1280], common.ElementType.f16),
-        common.ShapedType([3, 2, 20, 1024, 64], common.ElementType.f32),
+        common.ShapedType([2, 1024, 1280], tuner_ctx.type.f16),
+        common.ShapedType([3, 20, 64, 1280], tuner_ctx.type.f16),
+        common.ShapedType([3, 2, 20, 1024, 64], tuner_ctx.type.f32),
common.DispatchKind.contraction,
)

@@ -177,7 +191,7 @@ def test_apply_params_contract() -> None:
assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir


-def test_apply_params_batch_matmul() -> None:
+def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
mlir_template = [
"<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
"<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -188,9 +202,9 @@ def test_apply_params_batch_matmul() -> None:
tile_dims = "bmnk"
problem_size = common.ProblemSize(
common.MatmulSize(968, 320, 640, 64),
-        common.ShapedType([64, 968, 640], common.ElementType.f16),
-        common.ShapedType([64, 640, 320], common.ElementType.f16),
-        common.ShapedType([64, 968, 320], common.ElementType.f32),
+        common.ShapedType([64, 968, 640], tuner_ctx.type.f16),
+        common.ShapedType([64, 640, 320], tuner_ctx.type.f16),
+        common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
)

@@ -228,7 +242,7 @@ def test_apply_params_batch_matmul() -> None:
assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified


-def test_apply_params_batch_mmt_float() -> None:
+def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
mlir_template = [
"<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
"<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -238,9 +252,9 @@ def test_apply_params_batch_mmt_float() -> None:

problem_size = common.ProblemSize(
common.MatmulSize(4096, 640, 640, 2),
-        common.ShapedType([2, 4096, 640], common.ElementType.f16),
-        common.ShapedType([2, 640, 640], common.ElementType.f16),
-        common.ShapedType([2, 4096, 640], common.ElementType.f32),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.f16),
+        common.ShapedType([2, 640, 640], tuner_ctx.type.f16),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.f32),
common.DispatchKind.batch_mmt,
)

@@ -276,7 +290,7 @@ def test_apply_params_batch_mmt_float() -> None:
assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified


-def test_apply_params_batch_mmt_int() -> None:
+def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
mlir_template = [
"<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
"<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -286,9 +300,9 @@ def test_apply_params_batch_mmt_int() -> None:

problem_size = common.ProblemSize(
common.MatmulSize(4096, 640, 640, 2),
-        common.ShapedType([2, 4096, 640], common.ElementType.i8),
-        common.ShapedType([2, 640, 640], common.ElementType.i8),
-        common.ShapedType([2, 4096, 640], common.ElementType.i32),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 640, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i32),
common.DispatchKind.batch_mmt,
)

@@ -347,7 +361,7 @@ def test_apply_params_batch_mmt_int() -> None:
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable


-def test_apply_params_broadcast_rhs_mmt() -> None:
+def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
mlir_template = [
"<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
"<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -357,9 +371,9 @@ def test_apply_params_broadcast_rhs_mmt() -> None:

problem_size = common.ProblemSize(
common.MatmulSize(4096, 640, 640, 2),
-        common.ShapedType([2, 4096, 640], common.ElementType.i8),
-        common.ShapedType([640, 640], common.ElementType.i8),
-        common.ShapedType([2, 4096, 640], common.ElementType.i32),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i8),
+        common.ShapedType([640, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i32),
common.DispatchKind.broadcast_rhs_mmt,
)

@@ -422,7 +436,7 @@ def test_apply_params_broadcast_rhs_mmt() -> None:
assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable


-def test_detect_broadcast_rhs_mmt() -> None:
+def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
mlir_lines = [
r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 128, 128]]>} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
(Diff truncated: the remaining 5 changed files are not shown.)
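
As a usage sketch built on the new fixture, a test can check that the element types handed out by tuner_ctx.type really are MLIR types; this test name and its assertions are illustrative assumptions, not part of the commit:

def test_element_types_are_mlir_types(tuner_ctx: common.TunerContext) -> None:
    # The fixture keeps an ir.Context alive while the test runs, so these
    # accessors can build and return genuine MLIR type objects.
    assert isinstance(tuner_ctx.type.f16, ir.F16Type)
    assert isinstance(tuner_ctx.type.i8, ir.IntegerType)
    assert tuner_ctx.type.i8.width == 8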