[tuner]: use iree_gpu.MMAIntrinsic and iree_gpu.MMAAttr #605

Merged: 5 commits, Nov 26, 2024
13 changes: 6 additions & 7 deletions tuner/tuner/candidate_gen.py
@@ -52,7 +52,7 @@ def apply_configuration(
expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
@@ -119,7 +119,6 @@ def get_transform_function_mmt(

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)

return f"""
transform.named_sequence @{functionName}(%matmul: !transform.any_op {{transform.readonly}}) -> (!transform.any_op, !transform.any_param) {{
%mmt = transform.include @match_mmt_f16_f16_f32 failures(propagate) (%matmul) : (!transform.any_op) -> !transform.any_op
@@ -132,7 +131,7 @@ def get_transform_function_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -205,7 +204,7 @@ def get_transform_function_conv(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -266,7 +265,7 @@ def get_transform_function_broadcast_rhs_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -346,7 +345,7 @@ def get_transform_function_batch_mmt(
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
@@ -414,7 +413,7 @@ def get_transform_function_batch_matmul(
translation_info = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = #iree_gpu.mma_layout<{configuration.intrinsic}>,
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
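The key change in this file: `configuration.intrinsic` is now an `iree_gpu.MMAAttr`, so the hand-written `#iree_gpu.mma_layout<...>` wrapper is dropped from `repl0` and from each transform spec and the attribute is interpolated directly. A minimal sketch of the expected behavior is below; it is not code from this PR, and it assumes that `str(MMAAttr)` prints the full attribute assembly and that the active MLIR context has the IREE GPU dialect registered (the tuner's `TunerContext` and the test fixtures take care of the latter).

```python
# Minimal sketch (not code from this PR): how the new repl0 interpolation is
# expected to render. Assumptions: the iree.compiler Python bindings are
# installed, the active MLIR context has the IREE GPU dialect registered, and
# str(MMAAttr) prints the full attribute assembly, e.g.
# "#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>".
from iree.compiler import ir  # type: ignore
from iree.compiler.dialects import iree_gpu  # type: ignore

with ir.Context():
    mma_attr = iree_gpu.MMAAttr.get(iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16)
    repl0 = (
        f"<intrinsic = {mma_attr}, "
        f"subgroup_m_count = 1, subgroup_n_count = 4>"
    )
    print(repl0)
    # Expected shape of the output (assumption):
    # <intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 1, subgroup_n_count = 4>
```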
29 changes: 22 additions & 7 deletions tuner/tuner/candidate_gen_test.py
@@ -13,6 +13,7 @@
from typing import Generator

from iree.compiler import ir # type: ignore
from iree.compiler.dialects import iree_gpu # type: ignore

from . import candidate_gen
from . import common
@@ -45,10 +46,12 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:

M, N, K = 2048, 1280, 1280

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=16,
workgroup_size=[16, 16, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[8, 8, 8],
subgroup_m_count=16,
subgroup_n_count=16,
@@ -97,10 +100,12 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:

n, oh, ow, oc, fh, fw, ic = 2, 64, 64, 640, 3, 3, 640

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
@@ -161,10 +166,12 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.contraction,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
intrinsic=mma_attr,
tile_sizes=[480, 384, 32],
subgroup_m_count=1,
subgroup_n_count=4,
@@ -208,10 +215,12 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_matmul,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
intrinsic=mma_attr,
tile_sizes=[416, 320, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -258,10 +267,12 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -306,10 +317,12 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.batch_mmt,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
@@ -377,10 +390,12 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:
common.DispatchKind.broadcast_rhs_mmt,
)

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
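Each updated test follows the same two-line setup: pick an `iree_gpu.MMAIntrinsic` enum value and wrap it with `iree_gpu.MMAAttr.get` before passing it to `common.Configuration`. As a reading aid, the sketch below maps the removed `common.MfmaIntrinsic` helpers to the enum values that replace them; the mapping is inferred from the hunks in this file and in common_test.py and is not an exhaustive listing of the enum.

```python
# Sketch (inferred from the diff): mapping from the removed
# common.MfmaIntrinsic helper constructors to the iree_gpu.MMAIntrinsic
# enum values that replace them in the tests.
from iree.compiler.dialects import iree_gpu  # type: ignore

OLD_HELPER_TO_ENUM = {
    "mfma_f32_16x16x16_f16": iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
    "mfma_f32_32x32x8_f16": iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
    "mfma_i32_16x16x32_i8": iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
    "mfma_i32_32x32x16_i8": iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
}

# Typical test usage (requires an active MLIR context, e.g. the tuner_ctx fixture):
#   mma_attr = iree_gpu.MMAAttr.get(OLD_HELPER_TO_ENUM["mfma_f32_16x16x16_f16"])
#   config = common.Configuration(..., intrinsic=mma_attr, ...)
```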
72 changes: 11 additions & 61 deletions tuner/tuner/common.py
@@ -85,74 +85,24 @@ def MNK(self) -> tuple[int, int, int]:
return (self.matmul_size.M, self.matmul_size.N, self.matmul_size.K)


@dataclass
class MfmaIntrinsic:
output_type: ir.IntegerType | ir.FloatType
m: int
n: int
k: int
input_type: ir.IntegerType | ir.FloatType

def __str__(self) -> str:
input = str(self.input_type).upper()
output = str(self.output_type).upper()
return f"MFMA_{output}_{self.m}x{self.n}x{self.k}_{input}"

@staticmethod
def mfma_f32_16x16x16_f16():
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
return MfmaIntrinsic(f32, 16, 16, 16, f16)

@staticmethod
def mfma_f32_32x32x8_f16():
f16 = ir.F16Type.get()
f32 = ir.F32Type.get()
return MfmaIntrinsic(f32, 32, 32, 8, f16)

@staticmethod
def mfma_i32_16x16x32_i8():
i32 = ir.IntegerType.get_signless(32)
i8 = ir.IntegerType.get_signless(8)
return MfmaIntrinsic(i32, 16, 16, 32, i8)

@staticmethod
def mfma_i32_32x32x16_i8():
i32 = ir.IntegerType.get_signless(32)
i8 = ir.IntegerType.get_signless(8)
return MfmaIntrinsic(i32, 32, 32, 16, i8)

@staticmethod
def all():
return [
MfmaIntrinsic.mfma_f32_16x16x16_f16(),
MfmaIntrinsic.mfma_f32_32x32x8_f16(),
MfmaIntrinsic.mfma_i32_16x16x32_i8(),
MfmaIntrinsic.mfma_i32_32x32x16_i8(),
]


def get_compatible_mfma_intrinsics(
problem_size: ProblemSize,
mma_intrinsics: list[iree_gpu.MMAIntrinsic],
) -> list[MfmaIntrinsic]:
available_mma_intrinsics = [str(mma) for mma in mma_intrinsics]

def is_compatible(intrinsic: MfmaIntrinsic) -> bool:
if problem_size.res_type.element_type != intrinsic.output_type:
) -> list[iree_gpu.MMAIntrinsic]:
def is_compatible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
a_type, b_type, c_type = mma_attr.abc_element_types
if problem_size.res_type.element_type != c_type:
return False
if problem_size.dispatch_kind != DispatchKind.batch_matmul:
if problem_size.lhs_type.element_type != intrinsic.input_type:
return False
if problem_size.rhs_type.element_type != intrinsic.input_type:
if (
problem_size.lhs_type.element_type != a_type
or problem_size.rhs_type.element_type != b_type
):
return False

if str(intrinsic) not in available_mma_intrinsics:
return False

return True

return list(filter(is_compatible, MfmaIntrinsic.all()))
return list(filter(is_compatible, mma_intrinsics))


class ReorderWorkgroupsStrategy(Enum):
@@ -197,7 +147,7 @@ def __str__(self) -> str:
class Configuration:
subgroup_size: int
workgroup_size: list[int]
intrinsic: MfmaIntrinsic
intrinsic: iree_gpu.MMAAttr
tile_sizes: list[int]
subgroup_m_count: int
subgroup_n_count: int
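With the hand-rolled `MfmaIntrinsic` dataclass removed, element-type compatibility is read straight off the bound MMA attribute. Below is a rough sketch of the query the rewritten compatibility predicate performs; the standalone helper name is illustrative, and it assumes an active MLIR context with the IREE GPU dialect registered.

```python
# Rough sketch (not the PR's exact code): query the (A, B, C) element types
# of an MMA intrinsic the way the rewritten compatibility check does.
from iree.compiler.dialects import iree_gpu  # type: ignore

def intrinsic_element_types(mma_intrinsic: iree_gpu.MMAIntrinsic):
    # MMAIntrinsicAttr wraps the enum value; .mma yields the MMAAttr, whose
    # abc_element_types property returns the operand/result MLIR element types.
    mma_attr = iree_gpu.MMAIntrinsicAttr.get(mma_intrinsic).mma
    return mma_attr.abc_element_types

# E.g. for MFMA_F32_16x16x16_F16 this should report (f16, f16, f32), so a
# problem whose result element type is not f32 gets filtered out.
```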
25 changes: 11 additions & 14 deletions tuner/tuner/common_test.py
@@ -5,7 +5,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
Usage: python -m pytest candidate_gen_test.py
Usage: python -m pytest common_test.py
"""

import pytest
@@ -72,10 +72,12 @@ def test_gpu_pipeline_options() -> None:


def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
config = common.Configuration(
subgroup_size=32,
workgroup_size=[16, 16, 1],
intrinsic=common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
intrinsic=mma_attr,
tile_sizes=[4, 8, 16],
subgroup_m_count=1,
subgroup_n_count=1,
@@ -97,11 +99,6 @@ def test_get_pipeline_config(mlir_ctx: ir.Context) -> None:
)


def test_mfma_intrinsic_to_str(mlir_ctx: ir.Context) -> None:
assert str(common.MfmaIntrinsic.mfma_f32_16x16x16_f16()) == "MFMA_F32_16x16x16_F16"
assert str(common.MfmaIntrinsic.mfma_i32_32x32x16_i8()) == "MFMA_I32_32x32x16_I8"


def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
assert common.get_compatible_mfma_intrinsics(
common.ProblemSize(
@@ -116,8 +113,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert common.get_compatible_mfma_intrinsics(
@@ -133,8 +130,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
],
) == [
common.MfmaIntrinsic.mfma_i32_16x16x32_i8(),
common.MfmaIntrinsic.mfma_i32_32x32x16_i8(),
iree_gpu.MMAIntrinsic.MFMA_I32_16x16x32_I8,
iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8,
]

assert common.get_compatible_mfma_intrinsics(
@@ -150,8 +147,8 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_16x16x16_f16(),
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16,
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert common.get_compatible_mfma_intrinsics(
@@ -166,7 +163,7 @@ def test_get_compatible_mfma_intrinsics(tuner_ctx: common.TunerContext) -> None:
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
],
) == [
common.MfmaIntrinsic.mfma_f32_32x32x8_f16(),
iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16,
]

assert (