[TKW] Move IGEMM conv impl to common place. (#295)
Move TKW IGEMM conv impl from the test folder to some common place to
allow it to be reused outside the tests (e.g. in iree-kernel-benchmark).

Not sure what the proper place for it is; suggestions are welcome.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Hardcode84 authored Nov 27, 2024
1 parent e3b6c87 commit 9e79f4e
Showing 3 changed files with 187 additions and 122 deletions.
5 changes: 5 additions & 0 deletions iree/turbine/kernel/wave/templates/__init__.py
@@ -0,0 +1,5 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
167 changes: 167 additions & 0 deletions iree/turbine/kernel/wave/templates/conv.py
@@ -0,0 +1,167 @@
# Copyright 2024 The IREE Authors
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import iree.turbine.kernel as tk
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from typing import Any, Optional
from iree.turbine.kernel.lang.global_symbols import *


def get_igemm_conv2d(
    layout: str,
    n: int,
    h: int,
    w: int,
    c: int,
    hf: int,
    wf: int,
    nf: int,
    stride: int,
    mem_space: tkl.IndexSymbol = SHARED_ADDRESS_SPACE,
    block_m: Optional[int] = None,
    block_n: Optional[int] = None,
    block_k: Optional[int] = None,
    ratio_m: Optional[int] = None,
    ratio_n: Optional[int] = None,
) -> tuple["LaunchableWave", dict[tkl.IndexSymbol, Any]]:
    cf = c
    padding = 0  # TODO: only pad=0 is supported for now

    sym = tkl.sym
    N, C, H, W = sym.N, sym.C, sym.H, sym.W
    NF, HF, WF = sym.NF, sym.HF, sym.WF

    H_OUT = (H + 2 * padding - HF) // stride + 1
    W_OUT = (W + 2 * padding - WF) // stride + 1
    SZ_OUT = H_OUT * W_OUT

    K = HF * WF * C
    M = SZ_OUT * N
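
    # Implicit GEMM (im2col) view of the convolution: M flattens the batch and
    # output-pixel dimensions, K flattens the filter window and input channels,
    # and NF (the number of filters) is the other GEMM dimension.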

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)

    # Align C dim reading pattern to be contiguous for nhwc_hwcf pattern.
    x_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={
            N: i // SZ_OUT,
            C: j % C,
            H: (i % SZ_OUT) % W_OUT * stride + (j // C) % WF,
            W: (i % SZ_OUT) // W_OUT * stride + (j // C) // WF,
        },
        outputs={M: i, K: j},
    )
    w_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={NF: i % NF, C: j % C, HF: (j // C) % WF, WF: (j // C) // WF},
        outputs={NF: i, K: j},
    )
    out_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={M: i, NF: j},
        outputs={
            N: i // SZ_OUT,
            NF: j,
            H_OUT: (i % SZ_OUT) % W_OUT,
            W_OUT: (i % SZ_OUT) // W_OUT,
        },
    )

    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = tkl.sym.BLOCK_K
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD

    if layout == "nchw_fchw":
        x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16]
        we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16]
        out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32]
    elif layout == "nhwc_hwcf":
        x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16]
        we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16]
        out_type = tkl.Memory[N, H_OUT, W_OUT, NF, GLOBAL_ADDRESS_SPACE, tkl.f32]
    else:
        raise ValueError(f"Unsupported layout: {layout}")

    if block_m is None:
        block_m = 64

    if block_n is None:
        block_n = 128

    if block_k is None:
        block_k = 32

    if ratio_m is None:
        ratio_m = 2

    if ratio_n is None:
        ratio_n = 2

    # Expose user-constraints
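    # M is split across workgroup dim 1 and NF across dim 0; each workgroup
    # runs ratio_n x ratio_m waves of 64 threads, and the K dimension is
    # stepped through in BLOCK_K tiles by the reduction loop below.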
    constraints: list[tkw.Constraint] = []
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
    constraints += [tkw.WaveConstraint(NF, BLOCK_N / ratio_n)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(ratio_n, ratio_m, 1),
        )
    ]

    @tkw.wave(constraints)
    def conv(
        x: x_type,
        we: we_type,
        out: out_type,
    ):
        c_reg = tkl.Register[M, NF, tkl.f32](0.0)

        @tkw.reduction(K, init_args=[c_reg])
        def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
            a_reg = tkw.read(
                x,
                mapping=x_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            b_reg = tkw.read(
                we,
                mapping=w_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        tkw.write(
            repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD
        )

    symbols = {
        N: n,
        C: c,
        W: w,
        H: h,
        NF: nf,
        WF: wf,
        HF: hf,
        BLOCK_M: block_m,
        BLOCK_N: block_n,
        BLOCK_K: block_k,
        ELEMS_PER_THREAD: 4,
        ADDRESS_SPACE: mem_space,
    }

    return conv, symbols
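
For reference, a minimal sketch of how the relocated template could be driven outside the test suite (e.g. from a benchmark harness). The problem sizes, the CUDA device, and the launch options are illustrative assumptions modeled on the test usage below, not part of this commit:

import torch
import iree.turbine.kernel as tk
from iree.turbine.kernel.lang.global_symbols import SHARED_ADDRESS_SPACE
from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d

# Build the kernel and its symbol bindings (problem sizes are made up).
conv, symbols = get_igemm_conv2d(
    layout="nhwc_hwcf",
    n=2, h=64, w=64, c=64,
    hf=3, wf=3, nf=64,
    stride=1,
    mem_space=SHARED_ADDRESS_SPACE,
)

# f16 inputs and f32 output, matching the Memory types declared by the template.
x = torch.randn(2, 64, 64, 64, dtype=torch.float16, device="cuda")
we = torch.randn(3, 3, 64, 64, dtype=torch.float16, device="cuda")
out = torch.zeros(2, 62, 62, 64, dtype=torch.float32, device="cuda")  # H_OUT = W_OUT = 62

# Compile and run; additional run configuration (as in the test below) may be needed.
with tk.gen.TestLaunchContext(symbols, canonicalize=True, run=True):
    conv(x, we, out)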
137 changes: 15 additions & 122 deletions tests/kernel/wave/wave_e2e_test.py
@@ -8,6 +8,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.wave_sim import wave_sim
from iree.turbine.kernel.wave.templates.conv import get_igemm_conv2d
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel.wave.iree_utils import generate_iree_ref
from iree.turbine.kernel.wave.utils import (
@@ -912,114 +913,27 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
    convRef.weight = torch.nn.Parameter(we)
    out_ref = convRef(x).detach().to(torch.float32)

    sym = tkl.sym
    N, C, H, W = sym.N, sym.C, sym.H, sym.W
    NF, HF, WF = sym.NF, sym.HF, sym.WF

    H_OUT = (H + 2 * padding - HF) // stride + 1
    W_OUT = (W + 2 * padding - WF) // stride + 1
    SZ_OUT = H_OUT * W_OUT

    K = HF * WF * C
    M = SZ_OUT * N

    i = tkw.IndexMapping.iterator(0)
    j = tkw.IndexMapping.iterator(1)

    # Align C dim reading pattern to be contiguous for nhwc_hwcf pattern.
    x_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={
            N: i // SZ_OUT,
            C: j % C,
            H: (i % SZ_OUT) % W_OUT * stride + (j // C) % WF,
            W: (i % SZ_OUT) // W_OUT * stride + (j // C) // WF,
        },
        outputs={M: i, K: j},
    )
    w_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={NF: i % NF, C: j % C, HF: (j // C) % WF, WF: (j // C) // WF},
        outputs={NF: i, K: j},
    )
    out_mapping = tkw.IndexMapping(
        num_iterators=2,
        inputs={M: i, NF: j},
        outputs={
            N: i // SZ_OUT,
            NF: j,
            H_OUT: (i % SZ_OUT) % W_OUT,
            W_OUT: (i % SZ_OUT) // W_OUT,
        },
    )

    # Workgroup tile sizes
    BLOCK_M = tkl.sym.BLOCK_M
    BLOCK_N = tkl.sym.BLOCK_N
    BLOCK_K = tkl.sym.BLOCK_K
    # Address space (for GPU, shared(1) or global(0))
    ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE
    # Other hyperparameters
    ELEMS_PER_THREAD = tkl.sym.ELEMS_PER_THREAD

    if layout == "nchw_fchw":
        x_type = tkl.Memory[N, C, H, W, ADDRESS_SPACE, tkl.f16]
        we_type = tkl.Memory[NF, C, HF, WF, ADDRESS_SPACE, tkl.f16]
        out_type = tkl.Memory[N, NF, H_OUT, W_OUT, GLOBAL_ADDRESS_SPACE, tkl.f32]
        pass  # Nothing
    elif layout == "nhwc_hwcf":
        x_type = tkl.Memory[N, H, W, C, ADDRESS_SPACE, tkl.f16]
        we_type = tkl.Memory[HF, WF, C, NF, ADDRESS_SPACE, tkl.f16]
        out_type = tkl.Memory[N, H_OUT, W_OUT, NF, GLOBAL_ADDRESS_SPACE, tkl.f32]
        x = torch.permute(x, (0, 2, 3, 1)).contiguous()
        we = torch.permute(we, (2, 3, 1, 0)).contiguous()
        out_ref = torch.permute(out_ref, (0, 2, 3, 1)).contiguous()
    else:
        raise ValueError(f"Invalid layout: {layout}")

    ratio_m = 2
    ratio_n = 2

    # Expose user-constraints
    constraints: list[tkw.Constraint] = []
    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 1)]
    constraints += [tkw.WorkgroupConstraint(NF, BLOCK_N, 0)]
    constraints += [tkw.WaveConstraint(M, BLOCK_M / ratio_m)]
    constraints += [tkw.WaveConstraint(NF, BLOCK_N / ratio_n)]
    constraints += [tkw.TilingConstraint(K, BLOCK_K)]

    constraints += [
        tkw.HardwareConstraint(
            threads_per_wave=64,
            waves_per_block=(ratio_n, ratio_m, 1),
        )
    ]

    @tkw.wave(constraints)
    def conv(
        x: x_type,
        we: we_type,
        out: out_type,
    ):
        c_reg = tkl.Register[M, NF, tkl.f32](0.0)

        @tkw.reduction(K, init_args=[c_reg])
        def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
            a_reg = tkw.read(
                x,
                mapping=x_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            b_reg = tkw.read(
                we,
                mapping=w_mapping,
                elements_per_thread=ELEMS_PER_THREAD,
            )
            acc = tkw.mma(a_reg, b_reg, acc)
            return acc

        tkw.write(
            repeat, out, mapping=out_mapping, elements_per_thread=ELEMS_PER_THREAD
        )
    conv, symbols = get_igemm_conv2d(
        layout=layout,
        n=n,
        h=h,
        w=w,
        c=c,
        hf=hf,
        wf=wf,
        nf=nf,
        stride=stride,
        mem_space=mem_space,
    )

    config = get_default_run_config()

@@ -1037,28 +951,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
)

    with tk.gen.TestLaunchContext(
        {
            N: n,
            C: c,
            W: w,
            H: h,
            NF: nf,
            WF: wf,
            HF: hf,
            BLOCK_M: 64,
            BLOCK_N: 128,
            BLOCK_K: 32,
            ELEMS_PER_THREAD: 4,
            ADDRESS_SPACE: mem_space,
            READ_SHARED_DELAY: 1,
            WRITE_SHARED_DELAY: 1,
            READ_GLOBAL_DELAY: 2,
            WRITE_GLOBAL_DELAY: 2,
            MMA_DELAY: 1,
            SHARED_MEMORY_UNITS: 4,
            GLOBAL_MEMORY_UNITS: 4,
            MMA_UNITS: 4,
        },
        symbols,
        canonicalize=True,
        run=True,
        run_bench=run_bench,
