[tuner] Use ir.(Integer|Float)Type for element types #554

Merged · 3 commits · Nov 18, 2024
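What this PR changes: common.ShapedType now carries real MLIR element types (ir.IntegerType / ir.FloatType) instead of the old common.ElementType enum, so element types must be created inside a live ir.Context. A minimal sketch of the idea, assuming a per-context "type" helper with one property per element type (the actual helper lives in tuner/tuner/common.py and may differ):

    from iree.compiler import ir  # type: ignore

    class TypeHelper:
        """Illustrative only: materializes MLIR element types in the active ir.Context."""

        @property
        def f16(self) -> ir.F16Type:
            return ir.F16Type.get()

        @property
        def f32(self) -> ir.F32Type:
            return ir.F32Type.get()

        @property
        def i8(self) -> ir.IntegerType:
            return ir.IntegerType.get_signless(8)

        @property
        def i32(self) -> ir.IntegerType:
            return ir.IntegerType.get_signless(32)

    with ir.Context():
        t = TypeHelper()
        print(t.f16, t.i8)  # -> f16 i8

The properties call ir.*Type.get() lazily, so each lookup binds to whatever ir.Context is current; that is what lets the tests below share one fixture-managed context.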
86 changes: 45 additions & 41 deletions tuner/tuner/candidate_gen.py

@@ -517,48 +517,52 @@ def tune(
 
     with ir.Context() as ctx:
         tuner_context = TunerContext(ctx, tune_logger)
-        mlir_module: ir.Module = parse_mlir(mlir_text, tuner_context)
-        # Save the input file as the first candidate.
-        with open(path.join(output, f"0.mlir"), "w") as f:
-            f.write(mlir_text)
-
-        dispatch_tuner_registry = DispatchTunerRegistry()
-        dispatch_tuner_registry.register(
-            [
-                MmtTuner(),
-                ConvTuner(),
-                ContractionTuner(lhs_dims, rhs_dims, tile_dims),
-                BatchMmtTuner(),
-                BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
-            ]
-        )
+        with parse_mlir(mlir_text, tuner_context) as mlir_module:
+            # Save the input file as the first candidate.
+            with open(path.join(output, f"0.mlir"), "w") as f:
+                f.write(mlir_text)
+
+            dispatch_tuner_registry = DispatchTunerRegistry()
+            dispatch_tuner_registry.register(
+                [
+                    MmtTuner(),
+                    ConvTuner(),
+                    ContractionTuner(lhs_dims, rhs_dims, tile_dims),
+                    BatchMmtTuner(),
+                    BatchMatmulTuner(lhs_dims, rhs_dims, tile_dims),
+                ]
+            )
 
-        walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)
-
-        dispatch_tuner = walk_result.dispatch_tuner
-        assert dispatch_tuner, "No suitable dispatch tuner found"
-        problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
-        tune_logger.debug(str(problem_size))
-        configs = []
-        for i, config in enumerate(
-            generate_solutions(tuner_context, problem_size, num_subgroups)
-        ):
-            if i >= limit:
-                break
-            tune_logger.info(f"Solution #{i+1}: {config}")
-            configs.append(config)
-            tf_mlir = dispatch_tuner.apply_params(problem_size, mlir_template, config)
-
-            with open(path.join(output, f"{i+1}.mlir"), "w") as f:
-                f.write(tf_mlir.modified)
-            with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
-                f.write(tf_mlir.embeddable)
-
-        with open(path.join(output, "configs.pkl"), "wb") as file:
-            pickle.dump(configs, file)
-
-        tune_logger.info(f"Generated {len(configs)} candidates")
-        tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
+            walk_result: OpWalkResult = walk_mlir_op(
+                mlir_module, dispatch_tuner_registry
+            )
+
+            dispatch_tuner = walk_result.dispatch_tuner
+            assert dispatch_tuner, "No suitable dispatch tuner found"
+            problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
+            tune_logger.debug(str(problem_size))
+            configs = []
+            for i, config in enumerate(
+                generate_solutions(tune_logger, problem_size, num_subgroups)
+            ):
+                if i >= limit:
+                    break
+                tune_logger.info(f"Solution #{i+1}: {config}")
+                configs.append(config)
+                tf_mlir = dispatch_tuner.apply_params(
+                    problem_size, mlir_template, config
+                )
+
+                with open(path.join(output, f"{i+1}.mlir"), "w") as f:
+                    f.write(tf_mlir.modified)
+                with open(path.join(output, f"{i+1}_config.mlir"), "w") as f:
+                    f.write(tf_mlir.embeddable)
+
+            with open(path.join(output, "configs.pkl"), "wb") as file:
+                pickle.dump(configs, file)
+
+            tune_logger.info(f"Generated {len(configs)} candidates")
+            tune_logger.info(f"Configurations .pkl is stored in {output}/configs.pkl")
 
 
 def main():
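Note the structural change in tune() above: parse_mlir is now used as a context manager, so all work on the parsed module happens while the module (and its owning ir.Context) is guaranteed to be alive. A hypothetical sketch of that pattern; the real parse_mlir in this repo may be implemented differently:

    from contextlib import contextmanager
    from typing import Iterator

    from iree.compiler import ir  # type: ignore

    @contextmanager
    def parse_mlir_sketch(mlir_text: str, ctx: ir.Context) -> Iterator[ir.Module]:
        # Parse in the caller's context and lend the module out for the
        # duration of the with-block only.
        module = ir.Module.parse(mlir_text, ctx)
        try:
            yield module
        finally:
            # Dropping the reference scopes the module's lifetime to the
            # with-block instead of the whole enclosing function.
            del module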
72 changes: 43 additions & 29 deletions tuner/tuner/candidate_gen_test.py

@@ -10,17 +10,31 @@
 
 import pytest
 
+from typing import Generator
+
 from iree.compiler import ir  # type: ignore
 
 from . import candidate_gen
 from . import common
 
 
+@pytest.fixture
+def tuner_ctx() -> Generator[common.TunerContext, None, None]:
+    from logging import Logger
+    from unittest.mock import MagicMock
+
+    with ir.Context() as ctx:
+        logger: Logger = MagicMock(spec=Logger)
+        yield common.TunerContext(ctx, logger)
+
+
 def remove_comments(mlir: str) -> str:
     return "\n".join(
         filter(lambda x: not x.lstrip().startswith("//"), mlir.splitlines())
     )
 
 
-def test_apply_params_mmt() -> None:
+def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         "<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 16, subgroup_n_count = 16>",
         "<LLVMGPUVectorDistribute workgroup_size = [16, 16] subgroup_size = 16,",
@@ -44,9 +58,9 @@ def test_apply_params_mmt() -> None:
 
     problem_size = common.ProblemSize(
         common.MatmulSize(M, N, K),
-        common.ShapedType([M, K], common.ElementType.f16),
-        common.ShapedType([N, K], common.ElementType.f16),
-        common.ShapedType([M, N], common.ElementType.f32),
+        common.ShapedType([M, K], tuner_ctx.type.f16),
+        common.ShapedType([N, K], tuner_ctx.type.f16),
+        common.ShapedType([M, N], tuner_ctx.type.f32),
         common.DispatchKind.mmt,
     )
     tf_mlir = candidate_gen.MmtTuner().apply_params(problem_size, mlir_template, config)
@@ -73,7 +87,7 @@ def test_apply_params_mmt() -> None:
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "8"}' in modified
 
 
-def test_apply_params_conv() -> None:
+def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         "<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 16, subgroup_n_count = 16>",
         "<LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64,",
@@ -98,9 +112,9 @@ def test_apply_params_conv() -> None:
 
     problem_size = common.ProblemSize(
         common.MatmulSize(oh * ow, oc, fh * fw * ic),
-        common.ShapedType([n, oh + 2, ow + 2, oc], common.ElementType.f16),
-        common.ShapedType([fh, fw, ic, oc], common.ElementType.f16),
-        common.ShapedType([n, oh, ow, oc], common.ElementType.f32),
+        common.ShapedType([n, oh + 2, ow + 2, oc], tuner_ctx.type.f16),
+        common.ShapedType([fh, fw, ic, oc], tuner_ctx.type.f16),
+        common.ShapedType([n, oh, ow, oc], tuner_ctx.type.f32),
         common.DispatchKind.conv,
     )
     tf_mlir = candidate_gen.ConvTuner().apply_params(
@@ -130,7 +144,7 @@ def test_apply_params_conv() -> None:
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
 
 
-def test_apply_params_contract() -> None:
+def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         "<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 2, subgroup_n_count = 2>}>",
         "<LLVMGPUVectorDistribute workgroup_size = [128, 2, 1] subgroup_size = 64,",
@@ -141,9 +155,9 @@ def test_apply_params_contract() -> None:
     tile_dims = "*mnk"
     problem_size = common.ProblemSize(
         common.MatmulSize(2048, 3840, 1280),
-        common.ShapedType([2, 1024, 1280], common.ElementType.f16),
-        common.ShapedType([3, 20, 64, 1280], common.ElementType.f16),
-        common.ShapedType([3, 2, 20, 1024, 64], common.ElementType.f32),
+        common.ShapedType([2, 1024, 1280], tuner_ctx.type.f16),
+        common.ShapedType([3, 20, 64, 1280], tuner_ctx.type.f16),
+        common.ShapedType([3, 2, 20, 1024, 64], tuner_ctx.type.f32),
         common.DispatchKind.contraction,
     )
 
@@ -177,7 +191,7 @@ def test_apply_params_contract() -> None:
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in new_mlir
 
 
-def test_apply_params_batch_matmul() -> None:
+def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         "<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
         "<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -188,9 +202,9 @@ def test_apply_params_batch_matmul() -> None:
     tile_dims = "bmnk"
     problem_size = common.ProblemSize(
         common.MatmulSize(968, 320, 640, 64),
-        common.ShapedType([64, 968, 640], common.ElementType.f16),
-        common.ShapedType([64, 640, 320], common.ElementType.f16),
-        common.ShapedType([64, 968, 320], common.ElementType.f32),
+        common.ShapedType([64, 968, 640], tuner_ctx.type.f16),
+        common.ShapedType([64, 640, 320], tuner_ctx.type.f16),
+        common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
         common.DispatchKind.batch_matmul,
     )
 
@@ -228,7 +242,7 @@ def test_apply_params_batch_matmul() -> None:
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
 
 
-def test_apply_params_batch_mmt_float() -> None:
+def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         "<intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
         "<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -238,9 +252,9 @@ def test_apply_params_batch_mmt_float() -> None:
 
     problem_size = common.ProblemSize(
         common.MatmulSize(4096, 640, 640, 2),
-        common.ShapedType([2, 4096, 640], common.ElementType.f16),
-        common.ShapedType([2, 640, 640], common.ElementType.f16),
-        common.ShapedType([2, 4096, 640], common.ElementType.f32),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.f16),
+        common.ShapedType([2, 640, 640], tuner_ctx.type.f16),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.f32),
         common.DispatchKind.batch_mmt,
     )
 
@@ -276,7 +290,7 @@ def test_apply_params_batch_mmt_float() -> None:
     assert '{llvm_func_attrs = {"amdgpu-waves-per-eu" = "2"}' in modified
 
 
-def test_apply_params_batch_mmt_int() -> None:
+def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         "<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
         "<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -286,9 +300,9 @@ def test_apply_params_batch_mmt_int() -> None:
 
     problem_size = common.ProblemSize(
         common.MatmulSize(4096, 640, 640, 2),
-        common.ShapedType([2, 4096, 640], common.ElementType.i8),
-        common.ShapedType([2, 640, 640], common.ElementType.i8),
-        common.ShapedType([2, 4096, 640], common.ElementType.i32),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 640, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i32),
         common.DispatchKind.batch_mmt,
     )
 
@@ -347,7 +361,7 @@ def test_apply_params_batch_mmt_int() -> None:
     assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
 
 
-def test_apply_params_broadcast_rhs_mmt() -> None:
+def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
     mlir_template = [
         "<intrinsic = #iree_gpu.mma_layout<MFMA_I32_16x16x32_I8>, subgroup_m_count = 4, subgroup_n_count = 1>}>",
         "<LLVMGPUVectorDistribute workgroup_size = [64, 4, 1] subgroup_size = 64,",
@@ -357,9 +371,9 @@ def test_apply_params_broadcast_rhs_mmt() -> None:
 
     problem_size = common.ProblemSize(
         common.MatmulSize(4096, 640, 640, 2),
-        common.ShapedType([2, 4096, 640], common.ElementType.i8),
-        common.ShapedType([640, 640], common.ElementType.i8),
-        common.ShapedType([2, 4096, 640], common.ElementType.i32),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i8),
+        common.ShapedType([640, 640], tuner_ctx.type.i8),
+        common.ShapedType([2, 4096, 640], tuner_ctx.type.i32),
         common.DispatchKind.broadcast_rhs_mmt,
     )
 
@@ -422,7 +436,7 @@ def test_apply_params_broadcast_rhs_mmt() -> None:
     assert "workgroup_size = [128, 2, 1] subgroup_size = 64" in embeddable
 
 
-def test_detect_broadcast_rhs_mmt() -> None:
+def test_detect_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
     mlir_lines = [
         r"%18 = tensor.empty() : tensor<2x1024x10240xi32>",
         r"%19 = linalg.fill {lowering_config = #iree_codegen.lowering_config<tile_sizes = [[1, 64, 128, 128]]>} ins(%c0_i32 : i32) outs(%18 : tensor<2x1024x10240xi32>) -> tensor<2x1024x10240xi32>",
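Taken together, the test changes mean a ProblemSize can only be described while an ir.Context is active, which is exactly what the new tuner_ctx fixture provides. A rough usage sketch mirroring the fixture; the import path and the shapes are illustrative, not taken from the tests:

    from logging import Logger
    from unittest.mock import MagicMock

    from iree.compiler import ir  # type: ignore

    from tuner.tuner import common  # import path assumed from the file paths above

    with ir.Context() as ctx:
        logger: Logger = MagicMock(spec=Logger)
        tuner_ctx = common.TunerContext(ctx, logger)
        # Element types are now real MLIR types tied to ctx:
        lhs = common.ShapedType([2048, 1280], tuner_ctx.type.f16)
        res = common.ShapedType([2048, 3840], tuner_ctx.type.f32)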