Skip to content

Commit

Permalink
[tuner]: use python binding to select mma intrinsics (#586)
Browse files Browse the repository at this point in the history
This PR is relevant to the task in
#453: " Use IREE attributes for
MFMA intrinsics in the tuner".

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
  • Loading branch information
bangtianliu authored Nov 22, 2024
1 parent 779adc3 commit 530f4bd
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 9 deletions.
10 changes: 9 additions & 1 deletion tuner/tuner/candidate_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_codegen # type: ignore

from .common import *
from .dispatch_constraints import *
from .dispatch_parser import *
Expand Down Expand Up @@ -535,13 +537,19 @@ def tune(

walk_result: OpWalkResult = walk_mlir_op(mlir_module, dispatch_tuner_registry)

variant_op_list = iree_codegen.get_executable_variant_ops(mlir_module)
assert len(variant_op_list) == 1, "Expect one executable variant op"
variant_op = variant_op_list[0]
# Get the MMA intrinisic intructions supported by the target.
mma_list = iree_codegen.query_mma_intrinsics(variant_op)

dispatch_tuner = walk_result.dispatch_tuner
assert dispatch_tuner, "No suitable dispatch tuner found"
problem_size: ProblemSize = dispatch_tuner.get_shapes(mlir_template)
tune_logger.debug(str(problem_size))
configs = []
for i, config in enumerate(
generate_solutions(tune_logger, problem_size, num_subgroups)
generate_solutions(tune_logger, problem_size, num_subgroups, mma_list)
):
if i >= limit:
break
Expand Down
13 changes: 12 additions & 1 deletion tuner/tuner/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

from iree.compiler import ir # type: ignore

from iree.compiler.dialects import iree_gpu # type: ignore


class CommonTypes:
def __init__(self, ctx: ir.Context):
Expand Down Expand Up @@ -130,7 +132,12 @@ def all():
]


def get_compatible_mfma_intrinsics(problem_size: ProblemSize) -> list[MfmaIntrinsic]:
def get_compatible_mfma_intrinsics(
problem_size: ProblemSize,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> list[MfmaIntrinsic]:
available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]

def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
if problem_size.res_type.element_type != intrinsic.output_type:
return False
Expand All @@ -139,6 +146,10 @@ def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
return False
if problem_size.rhs_type.element_type != intrinsic.input_type:
return False

if str(intrinsic) not in available_mma_intrinsics:
return False

return True

return list(filter(is_compatible, MfmaIntrinsic.all()))
Expand Down
51 changes: 48 additions & 3 deletions tuner/tuner/common_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Generator

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore


@pytest.fixture
Expand Down Expand Up @@ -109,7 +110,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.f16),
common.ShapedType([2048, 1280], tuner_ctx.type.f32),
common.DispatchKind.mmt,
)
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
Expand All @@ -122,7 +127,11 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([1280, 1280], tuner_ctx.type.i8),
common.ShapedType([2048, 1280], tuner_ctx.type.i32),
common.DispatchKind.mmt,
)
),
[
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
) == [
common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
Expand All @@ -135,8 +144,44 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
)
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
]

assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
common.MatmulSize(968, 320, 640, 64),
common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
),
[
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
]

assert (
common.get_compatible_mfma_intrinsics(
common.ProblemSize(
common.MatmulSize(968, 320, 640, 64),
common.ShapedType([64, 968, 640], tuner_ctx.type.f32),
common.ShapedType([64, 640, 320], tuner_ctx.type.f32),
common.ShapedType([64, 968, 320], tuner_ctx.type.f32),
common.DispatchKind.batch_matmul,
),
[
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)
== []
)
15 changes: 12 additions & 3 deletions tuner/tuner/dispatch_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@
import z3 # type: ignore
from typing import Iterator


from iree.compiler.dialects import iree_gpu # type: ignore

from .common import *


Expand All @@ -18,8 +21,9 @@ def get_mfma_intrinsic_constraints(
intrinsic_m: z3.ArithRef,
intrinsic_n: z3.ArithRef,
intrinsic_k: z3.ArithRef,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> z3.BoolRef:
compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size)
compatible_intrinsics = get_compatible_mfma_intrinsics(problem_size, mma_intrinsics)
assert len(compatible_intrinsics) > 0, "No compatible intrinsics found"
return z3.Or(
*(
Expand Down Expand Up @@ -68,6 +72,7 @@ def generate_constraints(
subgroup_m_count,
subgroup_n_count,
waves_per_eu,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
):
M, N, K = (
problem_size.matmul_size.M,
Expand All @@ -82,7 +87,7 @@ def generate_constraints(
constraints += [subgroup_size == 64, wg_threads <= 1024]
constraints += [
get_mfma_intrinsic_constraints(
problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k
problem_size, intrinsic_mn, intrinsic_mn, intrinsic_k, mma_intrinsics
)
]
subgroup_k_count = 1
Expand Down Expand Up @@ -130,7 +135,10 @@ def generate_constraints(


def generate_solutions(
logger: logging.Logger, problem_size: ProblemSize, num_subgrups: int
logger: logging.Logger,
problem_size: ProblemSize,
num_subgrups: int,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> Iterator[Configuration]:
M, N, K = problem_size.MNK
logger.info(f"{M},{N},{K}")
Expand Down Expand Up @@ -168,6 +176,7 @@ def generate_solutions(
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
mma_intrinsics,
)
solver.add(z3.simplify(z3.And(constraints)))
logger.debug(f"Initial constraints: {solver}")
Expand Down
26 changes: 25 additions & 1 deletion tuner/tuner/dispatch_constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import Generator

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore

from . import common
from . import dispatch_constraints
Expand All @@ -37,7 +38,18 @@ def test_generate_solutions(tuner_ctx: common.TunerContext) -> None:
problem_size = common.ProblemSize(
matmul_size, lhs_type, rhs_type, res_type, common.DispatchKind.mmt
)
configs = dispatch_constraints.generate_solutions(tuner_ctx.logger, problem_size, 4)
configs = dispatch_constraints.generate_solutions(
tuner_ctx.logger,
problem_size,
4,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)

assert configs is not None


Expand Down Expand Up @@ -115,6 +127,12 @@ def test_generate_constraints_valid_input(tuner_ctx: common.TunerContext) -> Non
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)

solver = z3.Solver()
Expand Down Expand Up @@ -160,6 +178,12 @@ def test_generate_constraints_invalid_input(tuner_ctx: common.TunerContext) -> N
sg_m_cnt,
sg_n_cnt,
waves_per_eu,
[
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
)
constraints.append(m > 1000) # Adding an additional unsatisfiable constraint

Expand Down

0 comments on commit 530f4bd

Please sign in to comment.