[tuner]: use lowering config binding
Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
bangtianliu committed Nov 28, 2024
1 parent 1896d7a commit 589900d
Showing 7 changed files with 255 additions and 81 deletions.
40 changes: 29 additions & 11 deletions tuner/tuner/candidate_gen.py
@@ -42,6 +42,9 @@
def apply_configuration(
template: list[str], configuration: Configuration, tile_sizes: list[int]
) -> str:
intrinsic = configuration.intrinsic
subgroup_m_count = configuration.subgroup_m_count
subgroup_n_count = configuration.subgroup_n_count
tune_logger.info(f"Applying: {configuration}")
expr0 = re.compile(
r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
@@ -52,7 +55,7 @@ def apply_configuration
expr2 = re.compile(r"tile_sizes = \[\[([0-9]+)(, ([0-9]+))+\]\]")
expr3 = re.compile(r"gpu_pipeline_options = #iree_gpu\.pipeline_options<([^>]*)>")
expr4 = re.compile(r"\"amdgpu-waves-per-eu\" = \"([0-9])\"")
repl0 = f"<intrinsic = {configuration.intrinsic}, subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>"
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"
repl1 = f'LLVMGPUVectorDistribute workgroup_size = [{", ".join(map(str, configuration.workgroup_size))}] subgroup_size = {configuration.subgroup_size},'
repl2 = f'tile_sizes = [[{", ".join(map(str, tile_sizes))}]]'
repl3 = f"gpu_pipeline_options = {configuration.gpu_pipeline_options}"
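
The hunk above binds intrinsic, subgroup_m_count, and subgroup_n_count to locals once, so repl0 no longer repeats the configuration. prefix. A minimal, self-contained sketch of the substitution step, using assumed placeholder values rather than a real Configuration object:

import re

# Mirrors expr0/repl0 above; the bound values are illustrative assumptions.
expr0 = re.compile(
    r"<intrinsic = #iree_gpu\.mma_layout<(.+)>, subgroup_m_count = ([0-9]+), subgroup_n_count = ([0-9]+)>"
)
intrinsic = "#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>"
subgroup_m_count, subgroup_n_count = 16, 16
repl0 = f"<intrinsic = {intrinsic}, subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>"

template_line = "<intrinsic = #iree_gpu.mma_layout<MFMA_F32_32x32x8_F16>, subgroup_m_count = 1, subgroup_n_count = 1>"
print(expr0.sub(repl0, template_line))
# -> <intrinsic = #iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>, subgroup_m_count = 16, subgroup_n_count = 16>
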
@@ -116,6 +119,9 @@ def get_transform_function_mmt(
self, problem_size: ProblemSize, functionName: str, configuration: Configuration
) -> str:
tile_sizes = ", ".join(map(str, get_mmt_tile_sizes(configuration)))
intrinsic = configuration.intrinsic
subgroup_m_count = configuration.subgroup_m_count
subgroup_n_count = configuration.subgroup_n_count

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -131,8 +137,8 @@ def get_transform_function_mmt
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %matmul, %config : !transform.any_op, !transform.any_param
@@ -186,6 +192,9 @@ def get_transform_function_conv(
output = f"tensor<{dynamic_batch_output_ty}>"

tile_sizes = ", ".join(map(str, self.get_conv_tile_sizes(configuration)))
intrinsic = configuration.intrinsic
subgroup_m_count = configuration.subgroup_m_count
subgroup_n_count = configuration.subgroup_n_count

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -204,8 +213,8 @@ def get_transform_function_conv
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %conv, %config : !transform.any_op, !transform.any_param
@@ -245,6 +254,9 @@ def get_transform_function_broadcast_rhs_mmt(
configuration: Configuration,
) -> str:
tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
intrinsic = configuration.intrinsic
subgroup_m_count = configuration.subgroup_m_count
subgroup_n_count = configuration.subgroup_n_count

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -265,8 +277,8 @@ def get_transform_function_broadcast_rhs_mmt
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
@@ -329,6 +341,9 @@ def get_transform_function_batch_mmt(
configuration: Configuration,
) -> str:
tile_sizes = ", ".join(map(str, get_batch_mmt_tile_sizes(configuration)))
intrinsic = configuration.intrinsic
subgroup_m_count = configuration.subgroup_m_count
subgroup_n_count = configuration.subgroup_n_count

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -345,8 +360,8 @@ def get_transform_function_batch_mmt
translation_info = #iree_codegen.translation_info<LLVMGPUVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %generic, %config : !transform.any_op, !transform.any_param
@@ -395,6 +410,9 @@ def get_transform_function_batch_matmul(
tile_sizes = ", ".join(
map(str, get_contract_tile_sizes(configuration, tile_dims))
)
intrinsic = configuration.intrinsic
subgroup_m_count = configuration.subgroup_m_count
subgroup_n_count = configuration.subgroup_n_count

wg_x, wg_y, wg_z = configuration.workgroup_size
extra_config = get_pipeline_config(configuration)
@@ -413,8 +431,8 @@ def get_transform_function_batch_matmul
translation_info = #iree_codegen.translation_info<LLVMGPUPadAndVectorDistribute
workgroup_size = [{wg_x}, {wg_y}, {wg_z}] subgroup_size = {configuration.subgroup_size},
{{mma_schedule = #iree_gpu.mma_schedule<
intrinsic = {configuration.intrinsic},
subgroup_m_count = {configuration.subgroup_m_count}, subgroup_n_count = {configuration.subgroup_n_count}>
intrinsic = {intrinsic},
subgroup_m_count = {subgroup_m_count}, subgroup_n_count = {subgroup_n_count}>
{extra_config}}}>
> -> !transform.any_param
transform.yield %batch_matmul, %config : !transform.any_op, !transform.any_param
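
The same three-line binding recurs in every get_transform_function_* method in this file. A hypothetical helper, not part of this commit, could capture the pattern in one place:

def mma_schedule_params(configuration: Configuration):
    # Hypothetical convenience wrapper over the Configuration properties;
    # illustrative only, not code from this commit.
    return (
        configuration.intrinsic,
        configuration.subgroup_m_count,
        configuration.subgroup_n_count,
    )

# Inside any of the methods above:
# intrinsic, subgroup_m_count, subgroup_n_count = mma_schedule_params(configuration)
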
140 changes: 112 additions & 28 deletions tuner/tuner/candidate_gen_test.py
@@ -48,13 +48,25 @@ def test_apply_params_mmt(tuner_ctx: common.TunerContext) -> None:

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config_dict = {
"mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
ir.IntegerAttr.get(tuner_ctx.type.i32, 8),
]
),
"reduction": ir.ArrayAttr.get([]),
"subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
"subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
}

lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
config = common.Configuration(
subgroup_size=16,
workgroup_size=[16, 16, 1],
intrinsic=mma_attr,
tile_sizes=[8, 8, 8],
subgroup_m_count=16,
subgroup_n_count=16,
lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
prefetch_shared_memory=True
),
@@ -104,13 +116,24 @@ def test_apply_params_conv(tuner_ctx: common.TunerContext) -> None:

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config_dict = {
"mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(tuner_ctx.type.i32, 464),
ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
ir.IntegerAttr.get(tuner_ctx.type.i32, 16),
]
),
"reduction": ir.ArrayAttr.get([]),
"subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
"subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
}
lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=mma_attr,
tile_sizes=[464, 320, 16],
subgroup_m_count=1,
subgroup_n_count=4,
lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(
reorder_workgroups_strategy=iree_gpu.ReorderWorkgroupsStrategyAttr.get(
iree_gpu.ReorderWorkgroupsStrategy.Transpose
@@ -171,13 +194,25 @@ def test_apply_params_contract(tuner_ctx: common.TunerContext) -> None:

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config_dict = {
"mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(tuner_ctx.type.i32, 480),
ir.IntegerAttr.get(tuner_ctx.type.i32, 384),
ir.IntegerAttr.get(tuner_ctx.type.i32, 32),
]
),
"reduction": ir.ArrayAttr.get([]),
"subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 1),
"subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 4),
}

lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[256, 1, 1],
intrinsic=mma_attr,
tile_sizes=[480, 384, 32],
subgroup_m_count=1,
subgroup_n_count=4,
lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=2,
)
@@ -220,13 +255,26 @@ def test_apply_params_batch_matmul(tuner_ctx: common.TunerContext) -> None:

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_32x32x8_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config_dict = {
"mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(tuner_ctx.type.i32, 416),
ir.IntegerAttr.get(tuner_ctx.type.i32, 320),
ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
]
),
"reduction": ir.ArrayAttr.get([]),
"subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
"subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
}

lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)

config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=mma_attr,
tile_sizes=[416, 320, 128],
subgroup_m_count=2,
subgroup_n_count=2,
lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=2,
)
@@ -272,13 +320,25 @@ def test_apply_params_batch_mmt_float(tuner_ctx: common.TunerContext) -> None:

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config_dict = {
"mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
ir.IntegerAttr.get(tuner_ctx.type.i32, 64),
ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
]
),
"reduction": ir.ArrayAttr.get([]),
"subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
"subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
}

lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=2,
)
@@ -322,13 +382,25 @@ def test_apply_params_batch_mmt_int(tuner_ctx: common.TunerContext) -> None:

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config_dict = {
"mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
ir.IntegerAttr.get(tuner_ctx.type.i32, 64),
ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
]
),
"reduction": ir.ArrayAttr.get([]),
"subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
"subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
}

lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=4,
)
@@ -395,13 +467,25 @@ def test_apply_params_broadcast_rhs_mmt(tuner_ctx: common.TunerContext) -> None:

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_I32_32x32x16_I8
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
lowering_config_dict = {
"mma_kind": mma_attr,
"workgroup": ir.ArrayAttr.get(
[
ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
ir.IntegerAttr.get(tuner_ctx.type.i32, 64),
ir.IntegerAttr.get(tuner_ctx.type.i32, 128),
]
),
"reduction": ir.ArrayAttr.get([]),
"subgroup_m_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
"subgroup_n_count": ir.IntegerAttr.get(tuner_ctx.type.i32, 2),
}

lowering_config_attrs = ir.DictAttr.get(lowering_config_dict)
config = common.Configuration(
subgroup_size=64,
workgroup_size=[128, 2, 1],
intrinsic=mma_attr,
tile_sizes=[128, 64, 128],
subgroup_m_count=2,
subgroup_n_count=2,
lowering_config=iree_gpu.LoweringConfigAttr.get(lowering_config_attrs),
gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
waves_per_eu=4,
)
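
Every test above builds the same lowering_config_dict shape by hand. A hypothetical test helper, not part of this commit, could factor that construction out; it uses only the ir and iree_gpu bindings already referenced in the diff:

def make_lowering_config(
    tuner_ctx, mma_attr, workgroup, subgroup_m_count, subgroup_n_count
):
    # Hypothetical helper mirroring the dict literals in the tests above;
    # the empty "reduction" array matches what every test passes.
    i32 = tuner_ctx.type.i32
    lowering_config_dict = {
        "mma_kind": mma_attr,
        "workgroup": ir.ArrayAttr.get(
            [ir.IntegerAttr.get(i32, x) for x in workgroup]
        ),
        "reduction": ir.ArrayAttr.get([]),
        "subgroup_m_count": ir.IntegerAttr.get(i32, subgroup_m_count),
        "subgroup_n_count": ir.IntegerAttr.get(i32, subgroup_n_count),
    }
    return iree_gpu.LoweringConfigAttr.get(ir.DictAttr.get(lowering_config_dict))

# e.g. lowering_config=make_lowering_config(tuner_ctx, mma_attr, [128, 64, 128], 2, 2)
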
34 changes: 21 additions & 13 deletions tuner/tuner/common.py
@@ -105,26 +105,34 @@ def is_comptible(mma_intrinsic: iree_gpu.MMAIntrinsic) -> bool:
return list(filter(is_comptible, mma_intrinsics))


class ReorderWorkgroupsStrategy(Enum):
NONE = 0
SWIZZLE = 1
TRANSPOSE = 2

def __str__(self) -> str:
return self.name.title()


@dataclass
class Configuration:
subgroup_size: int
workgroup_size: list[int]
intrinsic: iree_gpu.MMAAttr
tile_sizes: list[int]
subgroup_m_count: int
subgroup_n_count: int
lowering_config: iree_gpu.LoweringConfigAttr
gpu_pipeline_options: iree_gpu.PipelineOptionsAttr
waves_per_eu: int

@property
def intrinsic(self) -> iree_gpu.MMAAttr:
return self.lowering_config.attributes["mma_kind"]

@property
def tilesize_workgroup(self) -> list[int]:
return [attr.value for attr in self.lowering_config.attributes["workgroup"]]

@property
def tilesize_reduction(self) -> list[int]:
return [attr.value for attr in self.lowering_config.attributes["reduction"]]

@property
def subgroup_m_count(self) -> int:
return self.lowering_config.attributes["subgroup_m_count"].value

@property
def subgroup_n_count(self) -> int:
return self.lowering_config.attributes["subgroup_n_count"].value


def get_pipeline_config(configuration: Configuration) -> str:
extra_config = ""
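
With these properties, call sites keep reading configuration.intrinsic, configuration.subgroup_m_count, and configuration.subgroup_n_count unchanged while the values now live inside the MLIR lowering_config attribute. A usage sketch, assuming a lowering_config built as in the tests above:

config = Configuration(
    subgroup_size=64,
    workgroup_size=[128, 2, 1],
    lowering_config=lowering_config,  # an iree_gpu.LoweringConfigAttr (assumed)
    gpu_pipeline_options=iree_gpu.PipelineOptionsAttr.get(),
    waves_per_eu=2,
)
# These reads go through the new properties rather than stored fields.
print(config.intrinsic)             # the "mma_kind" attribute
print(config.tilesize_workgroup)    # e.g. [128, 64, 128]
print(config.subgroup_m_count, config.subgroup_n_count)
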