From bd359685c7d2b6cb666dfa4eca1b6902b22d2a4f Mon Sep 17 00:00:00 2001 From: Bohan Hou Date: Tue, 8 Aug 2023 05:02:24 -0700 Subject: [PATCH] [Unity][DLight] Update gemv rule (#15490) --- python/tvm/dlight/gpu/gemv.py | 380 ++++++++++++++----- python/tvm/dlight/gpu/utils.py | 2 + tests/python/dlight/test_gpu_gemv.py | 536 ++++++++++++++++++--------- 3 files changed, 653 insertions(+), 265 deletions(-) diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py index 13dee1cd54..b063883800 100644 --- a/python/tvm/dlight/gpu/gemv.py +++ b/python/tvm/dlight/gpu/gemv.py @@ -16,6 +16,7 @@ # under the License. """A rule for GEMV and DecodeGEMV.""" import re +from functools import reduce from typing import List, Optional, Union from tvm import DataType, arith, ir, tir @@ -124,6 +125,8 @@ def normalize( if c_loops: return None loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor]) + # we expect the inner most dim to be grouped atm + assert not (is_reduction ^ is_inner_reduction) c_loops.append(c_loop) if is_reduction: r_loops.append(loop) @@ -169,6 +172,10 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- return None block_info = block_infos[0] + if len(block_info.iters) not in [2, 3]: + # either [B, S, R] = [B, S, R] * [B, R] + # or [S, R] = [S, R] * [R] + return None block = block_info.block_rv vector_input_buffers = is_gemv(sch, block_info) if vector_input_buffers is None: @@ -179,14 +186,13 @@ def apply( # pylint: disable=too-many-locals,too-many-branches,too-many-return- # Step 2. Do the scheduling if is_inner_reduction: - # print(func) self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue) return sch else: # TODO: Need to handle GEMV with KN layout return None - def sch_inner_reduction( # pylint: disable=too-many-arguments + def sch_inner_reduction( # pylint: disable=too-many-arguments, invalid-name, unused-argument self, sch: tir.Schedule, target: Target, @@ -195,106 +201,282 @@ def sch_inner_reduction( # pylint: disable=too-many-arguments epilogue_info: Optional[BlockInfo], ): """Schedule the inner reduction block.""" - # pylint: disable=invalid-name - _, s, r, _ = sch.get_loops(block) - # TODO: make it tunable - vec_bytes = 16 if target.kind.name == "cuda" else 8 - unroll_number = 256 if target.kind.name == "cuda" else 64 + + def get_max_factor(n, factors): + factors = sorted(factors, reverse=True) + for factor in factors: + if n % factor == 0: + return factor + return 1 + + def apply( + sch: tir.Schedule, + gemv, + TAG_S, + TAG_R, + TS, + TR, + TILE_S, + TILE_R, + VEC_LOAD, + VEC_C, + LOAD_V_SHARED, + LOAD_V_VEC, + UNROLL, + ): + # rfactor: reduce to tx * vec_c + _, s, r, c = sch.get_loops(block=gemv) + s = sch.fuse(_, s) + r = sch.fuse(r, c) + bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True) + r, tr, tile_r_vec_n, vec_c = sch.split( + r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True + ) + sch.reorder(r, tile_r_vec_n, tr, vec_c) + tr_vec_c = sch.fuse(tr, vec_c) + rf = sch.rfactor(tr_vec_c, 0) + + # rfactor: reduce to tx + bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + rf2 = sch.rfactor(tr, 0) + + # bind, vectorize compute + bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf) + tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True) + sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c) + sch.bind(bx, "blockIdx.x") + 
sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_c) + + shared_mem_usage = 0 + for buf in vector_input_buffers: + buf_size = reduce( + lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1) + ) * get_bytes(buf.dtype) + shared_mem_usage += buf_size + LOAD_V_SHARED = ( + LOAD_V_SHARED + and isinstance(shared_mem_usage, tir.IntImm) + and shared_mem_usage.value <= target.max_shared_memory_per_block + ) + + # vectorize load A + # (TODO) this is now actually problematic since the number of loops is dependent on the + # number of dimensions of A_q + Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local") + sch.compute_at(Aq_local, r, preserve_unit_loops=True) + s_local, r_local = sch.get_loops(block=Aq_local)[-2:] + s_local, vec_load = sch.split( + s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True + ) + sch.reorder(s_local, r_local, vec_load) # either s_local or r_local should be 1 + sch.vectorize(vec_load) + + # load vector into shared memory, shape should be the whole vector + if LOAD_V_SHARED: + assert len(vector_input_buffers) == 1 + V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared") + sch.compute_at(V_shared, tr, preserve_unit_loops=True) + l = sch.get_loops(block=V_shared)[-1] + loop: tir.For = sch.get(l) + if isinstance(loop.extent, tir.IntImm): + # avoid introducing predicates when vector length is too large + vec_length = max( + min( + get_max_factor( + (int)(loop.extent), + [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8], + ) + // TS + // TR, + LOAD_V_VEC, + ), + 1, + ) + else: + vec_length = LOAD_V_VEC + if TAG_R == "threadIdx.x": + _, ty, tx, vec = sch.split( + l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True + ) + else: + _, ty, tx, vec = sch.split( + l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True + ) + sch.bind(ty, "threadIdx.y") + sch.bind(tx, "threadIdx.x") + sch.vectorize(vec) + + # reduce tile_s * tr * vec to tile_s * tr + sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True) + tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + tile_s, vec_s = sch.split( + tile_s, + factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])], + preserve_unit_iters=True, + ) + sch.reorder(ts, tr, tile_s, vec_s, vec_c) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + sch.vectorize(vec_s) + + # reduce tile_s * tr to tile_s + sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True) + tr, *ts_tile_s = sch.get_loops(block=gemv)[1:] + ts_tile_s = sch.fuse(*ts_tile_s) + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.reorder(tile_s, ts, tr) + sch.bind(ts, TAG_S) + sch.bind(tr, TAG_R) + + sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[3]) + sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1]) + + sch.set_scope(rf, buffer_index=0, storage_scope="local") + sch.set_scope(rf2, buffer_index=0, storage_scope="local") + + unroll_factor = UNROLL + + sch.annotate( + block_or_loop=sch.get_loops(rf)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf)[3], ann_key="pragma_unroll_explicit", ann_val=1 + ) + + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], + ann_key="pragma_auto_unroll_max_step", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_unroll_explicit", ann_val=1 + ) + + if LOAD_V_SHARED: + 
sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], + ann_key="pragma_unroll_explicit", + ann_val=unroll_factor, + ) + sch.annotate( + block_or_loop=sch.get_loops(V_shared)[-4], ann_key="pragma_vectorize", ann_val=1 + ) + + # Schedule epilogue + if epilogue_info is not None: + epilogue = epilogue_info.block_rv + if is_broadcast_epilogue(sch, block, epilogue): + sch.reverse_compute_at(epilogue, bx) + sch.set_scope(block, 0, "shared") + _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name + _, tx = sch.split(sch.fuse(*s), factors=[None, TX]) + sch.bind(tx, "threadIdx.x") + else: + sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True) + ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:]) + ts_tile_s = sch.get_loops(epilogue)[-1] + ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True) + sch.bind(ts, TAG_S) + sch.set_scope(block, 0, "local") + # pylint: enable=invalid-name + return sch def get_extent(loop_rv: tir.schedule.LoopRV): loop: tir.For = sch.get(loop_rv) - return loop.extent.value if isinstance(loop.extent, tir.IntImm) else 1 + return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent # Specify the `len_tx` and `len_ty` according to the loop extent - len_s, len_r = get_extent(s), get_extent(r) - if len_r >= 4096 and len_r % 128 == 0: - len_tx = 128 - elif 1024 < len_r <= 2048 and len_r % 64 == 0: - len_tx = 64 - else: - len_tx = 32 - - if len_s >= 4096: - len_ty = 8 - else: - len_ty = min(len_s, 4) - - # Use `split_k` to prevent too large shared memory usage - split_k: int = 4 - - _, tx = sch.split(r, [None, len_tx], preserve_unit_iters=True) - # Schedule the RF block - rf = sch.rfactor(tx, 0) - batch, bx, r, tx, _ = sch.get_loops(rf) - sch.reorder(bx, tx, r) - ro, ri = sch.split(r, [split_k, None], preserve_unit_iters=True) - bx, ty = sch.split(bx, [None, len_ty], preserve_unit_iters=True) - - sch.bind(batch, "blockIdx.y") - sch.bind(bx, "blockIdx.x") - sch.bind(ty, "threadIdx.y") - sch.bind(tx, "threadIdx.x") - sch.annotate(ro, "pragma_auto_unroll_max_step", unroll_number) - sch.annotate(ro, "pragma_unroll_explicit", 1) - + batch, s, r, c = sch.get_loops(block=block) + len_batch, len_s, len_r, len_c = ( + get_extent(batch), + get_extent(s), + get_extent(r), + get_extent(c), + ) + len_S = len_batch * len_s + len_R = len_r * len_c + + TAG_S, TAG_R = "threadIdx.y", "threadIdx.x" if target.kind.name == "cuda": - # Cache read the vector - def cache_shared(index: int): - block: tir.Block = sch.get(rf) - type_bytes: int = get_bytes(block.reads[index].buffer.dtype) - cache = sch.cache_read(rf, index, "shared") - sch.compute_at(cache, ro, preserve_unit_loops=True) - fused = sch.fuse(*sch.get_loops(cache)[5:]) - loop: tir.For = sch.get(fused) - vec_length = vec_bytes // type_bytes - if isinstance(loop.extent, tir.IntImm): - # avoid introducing predicates when vector length is too large - vec_length = min(loop.extent // len_ty // len_tx, vec_length) - _, _ty, _tx, _vec = sch.split(fused, [None, len_ty, len_tx, vec_length]) - sch.bind(_ty, "threadIdx.y") - sch.bind(_tx, "threadIdx.x") - sch.vectorize(_vec) - - def cache_local(index: int): - block: tir.Block = sch.get(rf) - type_bytes: int = get_bytes(block.reads[index].buffer.dtype) - vec_length = vec_bytes // type_bytes - cache = sch.cache_read(rf, index, "local") - sch.compute_at(cache, ri, preserve_unit_loops=True) - fused = sch.fuse(*sch.get_loops(cache)[6:]) - loop: tir.For = sch.get(fused) - if isinstance(loop.extent, tir.IntImm) and loop.extent.value % vec_length 
== 0: - _, _vec = sch.split(fused, [None, vec_length]) - sch.vectorize(_vec) - elif isinstance(loop.extent, tir.IntImm) and loop.extent.value < vec_length: - sch.vectorize(fused) - - for buffer in vector_input_buffers: - index = vector_input_buffers.index(buffer) - cache_shared(index) - cache_local(index) - - # TODO: cache scale buffer in Decode-GEMV to shared memory - - sch.set_scope(rf, 0, "local") - sch.decompose_reduction(rf, ro) - # Schedule the write back block - sch.reverse_compute_at(block, ty, preserve_unit_loops=True) - _, _, _, tx, *s = sch.get_loops(block) - s = sch.fuse(*s) - sch.reorder(s, tx) - sch.bind(tx, "threadIdx.x") - # Schedule epilogue - if epilogue_info is not None: - epilogue = epilogue_info.block_rv - if is_broadcast_epilogue(sch, block, epilogue): - sch.reverse_compute_at(epilogue, bx) - sch.set_scope(block, 0, "shared") - _, _, *s = sch.get_loops(epilogue) # pylint: disable=invalid-name - _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx]) - sch.bind(tx, "threadIdx.x") - else: - # NOTE: Need to ensure tx_len == 32, so that can use `local` stage here - sch.reverse_compute_at(epilogue, ty) - sch.set_scope(block, 0, "local") - # pylint: enable=invalid-name + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 64 + else: + TS, TR = 16, 32 + elif target.kind.name == "metal": + VEC_C = 2 + LOAD_V_SHARED = True + LOAD_V_VEC = 4 + UNROLL = 256 + TS, TR = 64, 8 + elif target.kind.name == "rocm": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 8 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 1, 128 + else: + TS, TR = 8, 64 + elif target.kind.name == "opencl" and "android" in str(target.host): + TAG_S, TAG_R = "threadIdx.x", "threadIdx.y" + VEC_C = 8 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 8 + TS, TR = 2, 32 + elif target.kind.name == "vulkan": + VEC_C = 4 + LOAD_V_SHARED = True + LOAD_V_VEC = 4 + UNROLL = 256 + if isinstance(len_S, int): + if len_S > len_R: + TS, TR = 4, 32 + else: + TS, TR = 16, 32 + else: + VEC_C = 1 + LOAD_V_SHARED = False + LOAD_V_VEC = -1 + UNROLL = 64 + TS, TR = 1, 64 + + if not isinstance(len_S, int): + TS, TR = 1, 64 + TILE_S, TILE_R = ( + 1, + len_c + if len_c > 1 + else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1), + ) + VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C) + VEC_LOAD = 1 + + return apply( + sch, + gemv=block, + TAG_S=TAG_S, + TAG_R=TAG_R, + TS=TS, + TR=TR, + TILE_S=TILE_S, + TILE_R=TILE_R, + VEC_LOAD=VEC_LOAD, + VEC_C=VEC_C, + LOAD_V_SHARED=LOAD_V_SHARED, + LOAD_V_VEC=LOAD_V_VEC, + UNROLL=UNROLL, + ) diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py index 4fcc762942..9f9a9c5ae4 100644 --- a/python/tvm/dlight/gpu/utils.py +++ b/python/tvm/dlight/gpu/utils.py @@ -51,6 +51,8 @@ def suggest_threads_per_block( ) -> List[int]: if target.kind.name == "cuda": threads = 256 + elif target.kind.name == "rocm": + threads = 256 else: threads = 64 results: List[Optional[int]] = [] diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py index 6cb5cceb43..fd6850ac60 100644 --- a/tests/python/dlight/test_gpu_gemv.py +++ b/tests/python/dlight/test_gpu_gemv.py @@ -90,61 +90,92 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n)) # with T.block("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n), "float16", scope="local") 
- var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 32, 1, n), "float16", scope="local") + var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 32, 1, n), "float16", scope="local") + var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1, n), "float16", scope="local") + lv1638_local = T.alloc_buffer((1, 32, n, 128), "float16", scope="local") lv1637_shared = T.alloc_buffer((1, 32, 1, 128), "float16", scope="shared") - lv1637_shared_local = T.alloc_buffer((1, 32, 1, 128), "float16", scope="local") - for ax0_fused in T.thread_binding(32, thread="blockIdx.y"): - for ax1_fused_0 in T.thread_binding(n, thread="blockIdx.x"): - for ax1_fused_1 in T.thread_binding(1, thread="threadIdx.y"): - for ax2_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul_rf_init"): - vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused]) - v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1) - T.reads() - T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1]) - var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = T.float16(0) - for ax2_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax0_ax1_ax2_ax3_fused_0 in range(1): - for ax0_ax1_ax2_ax3_fused_1 in T.thread_binding(1, thread="threadIdx.y"): - for ax0_ax1_ax2_ax3_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_ax3_fused_3 in T.vectorized(1): - with T.block("lv1637_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(32, ax0_fused) - v2 = T.axis.spatial(1, 0) - v3 = T.axis.spatial(128, ax2_fused_0_0 * 32 + ax0_ax1_ax2_ax3_fused_0 * 32 + ax0_ax1_ax2_ax3_fused_1 * 32 + ax0_ax1_ax2_ax3_fused_2 + ax0_ax1_ax2_ax3_fused_3) - T.reads(lv1637[v0, v1, v2, v3]) - T.writes(lv1637_shared[v0, v1, v2, v3]) - lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3] - for ax2_fused_0_1 in range(1): - for ax0_ax1_ax2_ax3_fused in T.vectorized(1): - with T.block("lv1637_shared_local"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(32, ax0_fused) - v2 = T.axis.spatial(1, 0) - v3 = T.axis.spatial(128, ax2_fused_0_0 * 32 + ax2_fused_1) - T.reads(lv1637_shared[v0, v1, v2, v3]) - T.writes(lv1637_shared_local[v0, v1, v2, v3]) - lv1637_shared_local[v0, v1, v2, v3] = lv1637_shared[v0, v1, v2, v3] - for u in range(1): - with T.block("NT_matmul_rf_update"): - vax2_fused_1, v0 = T.axis.remap("SS", [ax2_fused_1, ax0_fused]) - v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1) - vax2_fused_0 = T.axis.reduce(4, ax2_fused_0_0 + ax2_fused_0_1) - T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1], lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1], lv1638[0, v0, v1, vax2_fused_0 * 32 + vax2_fused_1]) - T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1]) - var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] + lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1] * lv1638[0, v0, v1, vax2_fused_0 * 32 + vax2_fused_1] - for ax1_ax2_fused in range(1): - for ax0 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul"): - vax2_fused_1, v0, v1 = T.axis.remap("RSS", [ax0, ax0_fused, ax1_fused_0]) - T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1]) - T.writes(var_NT_matmul_intermediate_local[0, v0, 0, v1]) - with T.init(): - var_NT_matmul_intermediate_local[0, v0, 0, v1] = T.float16(0) - var_NT_matmul_intermediate_local[0, v0, 0, v1] = 
var_NT_matmul_intermediate_local[0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] + for ax0_fused_ax1_fused_fused_0 in T.thread_binding(n * 32, thread="blockIdx.x"): + for ax0_fused_ax1_fused_fused_1 in T.thread_binding(1, thread="threadIdx.y"): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.x"): + for ax0, ax1, ax2 in T.grid(1, 1, 1): + for ax3_0 in T.serial(1, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax3_1 in T.thread_binding(1, thread="threadIdx.y"): + for ax3_2 in T.thread_binding(64, thread="threadIdx.x"): + for ax3_3 in T.vectorized(2): + with T.block("lv1637_shared"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1) + v2 = T.axis.spatial(1, ax2) + v3 = T.axis.spatial(128, ax3_0 * 128 + ax3_1 * 128 + ax3_2 * 2 + ax3_3) + T.reads(lv1637[v0, v1, v2, v3]) + T.writes(lv1637_shared[v0, v1, v2, v3]) + lv1637_shared[v0, v1, v2, v3] = lv1637[v0, v1, v2, v3] + for ax0_fused_ax1_fused_fused_2_init in range(1): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(2): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init) + v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) // n) + v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2_init) % n) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1]) + var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = T.float16(0) + for ax2_fused_u_fused_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2): + for ax2_1 in T.vectorized(1): + with T.block("lv1638_local"): + v0 = T.axis.spatial(1, ax0) + v1 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n + ax1) + v2 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1) + v3 = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3) + T.reads(lv1638[v0, v1, v2, v3]) + T.writes(lv1638_local[v0, v1, v2, v3]) + lv1638_local[v0, v1, v2, v3] = lv1638[v0, v1, v2, v3] + for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2 in T.grid(1, 1): + for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(2): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1) + v0 = T.axis.spatial(32, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) // n) + v1 = T.axis.spatial(n, (ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 + ax0_fused_ax1_fused_fused_2) % n) + vax2_fused_u_fused_2, vax2_fused_u_fused_0 = T.axis.remap("RR", [ax2_fused_u_fused_2, ax2_fused_u_fused_0]) + T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1], lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused], lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + 
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused]) + T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1]) + var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused, 0, v0, 0, v1] + lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] * lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] + for ax2_ax3_fused_0 in T.thread_binding(1, thread="threadIdx.y"): + for ax0 in T.thread_binding(64, thread="threadIdx.x"): + for ax2_ax3_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_ax3_fused_1_1 in T.vectorized(1): + with T.block("NT_matmul_rf_init"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0) + v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) + v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]) + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = T.float16(0) + for ax1 in range(2): + with T.block("NT_matmul_rf_update"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) + v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) + T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1], var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1]) + T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]) + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] = var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1] + for ax1_ax2_fused_1 in range(1): + for ax1_ax2_fused_0 in T.thread_binding(1, thread="threadIdx.y"): + for ax0 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("NT_matmul"): + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.reduce(64, ax0) + v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) + v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n) + T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1]) + T.writes(var_NT_matmul_intermediate_local[0, v0, 0, v1]) + with T.init(): + var_NT_matmul_intermediate_local[0, v0, 0, v1] = T.float16(0) + var_NT_matmul_intermediate_local[0, v0, 0, v1] = var_NT_matmul_intermediate_local[0, v0, 0, v1] + var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0, 0, v0, 0, v1] + for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"): + for ax0_ax1_fused_1 in range(1): with T.block("compute"): - v0, v1 = T.axis.remap("SS", [ax0_fused, ax1_fused_0]) + v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 // n) + v1 = T.axis.spatial(n, 
ax0_fused_ax1_fused_fused_0 % n) T.reads(var_NT_matmul_intermediate_local[0, v0, 0, v1], lv1614[0, 0, 0, v1]) T.writes(var_compute_intermediate[0, v0, 0, v1]) var_compute_intermediate[0, v0, 0, v1] = T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[0, v0, 0, v1] * T.float16(0.088397790055248615), T.float16(-65504)), lv1614[0, 0, 0, v1])) @@ -152,10 +183,10 @@ def expected(lv1637: T.Buffer((1, 32, 1, 128), "float16"), p_lv1638: T.handle, p # fmt: on -class TestDecodeGEMV1(BaseBeforeAfter): +def test_decode_gemv1(): # fmt: off - @T.prim_func + @T.prim_func(private=True) def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): @@ -175,72 +206,95 @@ def before(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] * p_output0_intermediate[v_i2, v_k] - @T.prim_func + @T.prim_func(private=True) def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572: T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"), var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")): T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): - var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 22016), "float16", scope="local") + var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1, 22016), "float16", scope="local") + var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1, 22016), "float16", scope="local") + lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local") lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared") - lv1654_shared_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") - for u_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_fused_0 in T.thread_binding(2752, thread="blockIdx.x"): - for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"): - for ax1_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul_rf_init"): - vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1) - v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1) - T.reads() - T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) - var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) - for ax1_0_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax0_ax1_ax2_fused_0 in range(1): - for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.y"): - for ax0_ax1_ax2_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_fused_3 in T.vectorized(4): - with T.block("lv1654_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) - v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 + ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3) - T.reads(lv1654[v0, v1, v2]) - T.writes(lv1654_shared[v0, v1, v2]) - lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2] - for ax1_0_fused_0_1 in range(4): - for ax0_ax1_ax2_fused_0 in range(1): - for ax0_ax1_ax2_fused_1 in T.vectorized(8): - with T.block("lv1654_shared_local"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) - v2 = 
T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1) - T.reads(lv1654_shared[v0, v1, v2]) - T.writes(lv1654_shared_local[v0, v1, v2]) - lv1654_shared_local[v0, v1, v2] = lv1654_shared[v0, v1, v2] - for ax1_1 in range(8): - with T.block("NT_matmul_rf_update"): - vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1) - v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1) - vax1_0_fused_0 = T.axis.reduce(16, ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1) - vax1_1 = T.axis.reduce(8, ax1_1) - T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv572[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32]) - T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) - var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32]) - for ax1_fused in range(1): - for ax0 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul"): - vax1_0_fused_1 = T.axis.reduce(32, ax0) - v0 = T.axis.spatial(22016, ax0_fused_0 * 8 + ax0_fused_1) - T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) - T.writes(var_NT_matmul_intermediate[0, 0, v0]) - with T.init(): - var_NT_matmul_intermediate[0, 0, v0] = T.float16(0) - var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + for u_fused_ax0_fused_fused_0 in T.thread_binding(5504, thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.x"): + for ax0, ax1 in T.grid(1, 1): + for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(4, thread="threadIdx.y"): + for ax2_2 in T.thread_binding(64, thread="threadIdx.x"): + for ax2_3 in T.vectorized(8): + with T.block("lv1654_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 512 + ax2_2 * 8 + ax2_3) + T.reads(lv1654[v0, v1, v2]) + T.writes(lv1654_shared[v0, v1, v2]) + lv1654_shared[v0, v1, v2] = lv1654[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(1): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4): + with T.block("NT_matmul_rf_init"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) + 
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) + for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_0, ax1 in T.grid(1, 1): + for ax0_1 in T.vectorized(1): + with T.block("lv571_local"): + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) + v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) + T.reads(lv571[v0, v1]) + T.writes(lv571_local[v0, v1]) + lv571_local[v0, v1] = lv571[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4): + with T.block("NT_matmul_rf_update"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) + T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32]) + T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv572[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32]) + for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax0 in T.thread_binding(64, thread="threadIdx.x"): + for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_1_1 in 
T.vectorized(1): + with T.block("NT_matmul_rf_init"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(64, ax0) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) + for ax1 in range(4): + with T.block("NT_matmul_rf_update"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) + T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) + T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] + for ax1_fused_1 in range(1): + for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax0 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(64, ax0) + v0 = T.axis.spatial(22016, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1) + T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) + T.writes(var_NT_matmul_intermediate[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate[0, 0, v0] = T.float16(0) + var_NT_matmul_intermediate[0, 0, v0] = var_NT_matmul_intermediate[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("nvidia/geforce-rtx-3090-ti"): + mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) -class TestDecodeGEMV2(BaseBeforeAfter): + +def test_decode_gemv2(): # fmt: off - @T.prim_func + @T.prim_func(private=True) def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): T.func_attr({"tir.noalias": T.bool(True)}) # with T.block("root"): @@ -267,73 +321,223 @@ def before(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128) T.writes(p_output0_intermediate[v_i0, v_i1, v_i2]) p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32", var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) - @T.prim_func + @T.prim_func(private=True) def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772: T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"), p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")): 
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)}) # with T.block("root"): var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 32000), "float16", scope="local") - var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1, 32000), "float16", scope="local") + var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1, 32000), "float16", scope="local") + var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1, 32000), "float16", scope="local") + lv771_local = T.alloc_buffer((32000, 512), "uint32", scope="local") lv3216_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared") - lv3216_shared_local = T.alloc_buffer((1, 1, 4096), "float16", scope="local") - for u_fused in T.thread_binding(1, thread="blockIdx.y"): - for ax0_fused_0 in T.thread_binding(4000, thread="blockIdx.x"): - for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"): - for ax1_0_fused_1 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul_rf_init"): - vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1) - v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1) - T.reads() - T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) - var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0) - for ax1_0_fused_0_0 in T.serial(4, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): - for ax0_ax1_ax2_fused_0 in range(1): - for ax0_ax1_ax2_fused_1 in T.thread_binding(8, thread="threadIdx.y"): - for ax0_ax1_ax2_fused_2 in T.thread_binding(32, thread="threadIdx.x"): - for ax0_ax1_ax2_fused_3 in T.vectorized(4): - with T.block("lv3216_shared"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) - v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128 + ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3) - T.reads(lv3216[v0, v1, v2]) - T.writes(lv3216_shared[v0, v1, v2]) - lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2] - for ax1_0_fused_0_1 in range(4): - for ax0_ax1_ax2_fused_0 in range(1): - for ax0_ax1_ax2_fused_1 in T.vectorized(8): - with T.block("lv3216_shared_local"): - v0 = T.axis.spatial(1, 0) - v1 = T.axis.spatial(1, 0) - v2 = T.axis.spatial(4096, ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 + ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1) - T.reads(lv3216_shared[v0, v1, v2]) - T.writes(lv3216_shared_local[v0, v1, v2]) - lv3216_shared_local[v0, v1, v2] = lv3216_shared[v0, v1, v2] - for ax1_1 in range(8): - with T.block("NT_matmul_rf_update"): - vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1) - v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1) - vax1_0_fused_0 = T.axis.reduce(16, ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1) - vax1_1 = T.axis.reduce(8, ax1_1) - T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0], lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1], lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv772[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32]) - T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) - var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) % 8) * 
T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32]) - for ax1_fused in range(1): - for ax0 in T.thread_binding(32, thread="threadIdx.x"): - with T.block("NT_matmul"): - vax1_0_fused_1 = T.axis.reduce(32, ax0) - v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1) - T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]) - T.writes(var_NT_matmul_intermediate_local[0, 0, v0]) - with T.init(): - var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0) - var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] + for u_fused_ax0_fused_fused_0 in T.thread_binding(8000, thread="blockIdx.x"): + for u_fused_ax0_fused_fused_1 in T.thread_binding(4, thread="threadIdx.y"): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(64, thread="threadIdx.x"): + for ax0, ax1 in T.grid(1, 1): + for ax2_0 in T.serial(2, annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}): + for ax2_1 in T.thread_binding(4, thread="threadIdx.y"): + for ax2_2 in T.thread_binding(64, thread="threadIdx.x"): + for ax2_3 in T.vectorized(8): + with T.block("lv3216_shared"): + v0, v1 = T.axis.remap("SS", [ax0, ax1]) + v2 = T.axis.spatial(4096, ax2_0 * 2048 + ax2_1 * 512 + ax2_2 * 8 + ax2_3) + T.reads(lv3216[v0, v1, v2]) + T.writes(lv3216_shared[v0, v1, v2]) + lv3216_shared[v0, v1, v2] = lv3216[v0, v1, v2] + for u_fused_ax0_fused_fused_2_init in range(1): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(4): + with T.block("NT_matmul_rf_init"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init) + v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = T.float16(0) + for ax1_0_fused_ax1_1_fused_0 in T.serial(8, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax0_0, ax1 in T.grid(1, 1): + for ax0_1 in T.vectorized(1): + with T.block("lv771_local"): + v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1) + v1 = T.axis.spatial(512, ax1_0_fused_ax1_1_fused_0 * 64 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1) + T.reads(lv771[v0, v1]) + T.writes(lv771_local[v0, v1]) + lv771_local[v0, v1] = lv771[v0, v1] + for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2): + for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4): + with T.block("NT_matmul_rf_update"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1) + v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2) + vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2]) + 
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0], lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4], lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32]) + T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0]) + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, 0, 0, v0] + lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + vax1_0_fused_ax1_1_fused_2 // 2], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) % 8) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv772[v0, (vax1_0_fused_ax1_1_fused_0 * 512 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 + vax1_0_fused_ax1_1_fused_2 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4) // 32]) + for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax0 in T.thread_binding(64, thread="threadIdx.x"): + for ax2_fused_1_0 in T.serial(1, annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}): + for ax2_fused_1_1 in T.vectorized(1): + with T.block("NT_matmul_rf_init"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(64, ax0) + v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) + T.reads() + T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = T.float16(0) + for ax1 in range(4): + with T.block("NT_matmul_rf_update"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1]) + v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1) + T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]) + T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) + 
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0] + for ax1_fused_1 in range(1): + for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax0 in T.thread_binding(64, thread="threadIdx.x"): + with T.block("NT_matmul"): + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(64, ax0) + v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1) + T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0]) + T.writes(var_NT_matmul_intermediate_local[0, 0, v0]) + with T.init(): + var_NT_matmul_intermediate_local[0, 0, v0] = T.float16(0) + var_NT_matmul_intermediate_local[0, 0, v0] = var_NT_matmul_intermediate_local[0, 0, v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, 0, 0, v0] + for ax0_fused_0 in T.thread_binding(4, thread="threadIdx.y"): + for ax0_fused_1 in range(1): with T.block("compute"): - v0 = T.axis.spatial(32000, ax0_fused_0 * 8 + ax0_fused_1) + v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 * 4 + ax0_fused_0 + ax0_fused_1) T.reads(var_NT_matmul_intermediate_local[0, 0, v0]) T.writes(p_output0_intermediate[0, 0, v0]) p_output0_intermediate[0, 0, v0] = T.Cast("float32", var_NT_matmul_intermediate_local[0, 0, v0]) # fmt: on + mod = tvm.IRModule({"main": before}) + with Target("nvidia/geforce-rtx-3090-ti"): + mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod) + tvm.ir.assert_structural_equal(mod["main"], expected) + + +def test_decode_gemv3(): + # fmt: off + + @T.prim_func(private=True) + def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")): + T.func_attr({"tir.noalias": T.bool(True)}) + # with T.block("root"): + p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096), T.int64(11008)), "float16") + var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16") + for i, j in T.grid(T.int64(4096), T.int64(11008)): + with T.block("decode"): + v_i, v_j = T.axis.remap("SS", [i, j]) + T.reads(lv575[v_i, v_j // T.int64(8)], lv576[v_i, v_j // T.int64(32)]) + T.writes(p_output0_intermediate_1[v_i, v_j]) + p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv575[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i, v_j // T.int64(32)] + for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(11008)): + with T.block("NT_matmul"): + v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k]) + T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2, v_k]) + T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2]) + with T.init(): + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0) + var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] * 
p_output0_intermediate_1[v_i2, v_k]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(lv570[v_ax0, v_ax1, v_ax2], var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0, v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+    @T.prim_func(private=True)
+    def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574: T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
+        T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128), T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local")
+        var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32), T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local")
+        lv575_local = T.alloc_buffer((T.int64(4096), T.int64(1376)), "uint32", scope="local")
+        lv574_shared = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16", scope="shared")
+        for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(256), thread="blockIdx.x"):
+            for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                        for ax2_0 in T.serial(T.int64(22), annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+                            for ax2_1 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                                for ax2_2 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                                    for ax2_3 in T.vectorized(T.int64(1)):
+                                        with T.block("lv574_shared"):
+                                            v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                                            v2 = T.axis.spatial(T.int64(11008), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) + ax2_2 + ax2_3)
+                                            T.where((ax2_0 * T.int64(16) + ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(11008))
+                                            T.reads(lv574[v0, v1, v2])
+                                            T.writes(lv574_shared[v0, v1, v2])
+                                            lv574_shared[v0, v1, v2] = lv574[v0, v1, v2]
+                    for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
+                        for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in T.vectorized(T.int64(4)):
+                            with T.block("NT_matmul_rf_init"):
+                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+                                v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2_init)
+                                T.reads()
+                                T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
+                                var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = T.float16(0)
+                    for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax0_0, ax1 in T.grid(T.int64(1), T.int64(1)):
+                            for ax0_1 in T.vectorized(T.int64(1)):
+                                with T.block("lv575_local"):
+                                    v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
+                                    v1 = T.axis.spatial(T.int64(1376), ax1_0_fused_ax1_1_fused_0 * T.int64(32) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+                                    T.reads(lv575[v0, v1])
+                                    T.writes(lv575_local[v0, v1])
+                                    lv575_local[v0, v1] = lv575[v0, v1]
+                        for u_fused_ax0_fused_fused_2, ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
+                            for ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(T.int64(4)):
+                                with T.block("NT_matmul_rf_update"):
+                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused = T.axis.spatial(T.int64(128), ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+                                    v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + u_fused_ax0_fused_fused_2)
+                                    vax1_0_fused_ax1_1_fused_0, vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0, ax1_0_fused_ax1_1_fused_2])
+                                    T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0], lv574_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)], lv575_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], lv576[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
+                                    T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0])
+                                    var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] = var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused, T.int64(0), T.int64(0), v0] + lv574_shared[T.int64(0), T.int64(0), vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)] * ((T.Cast("float16", T.bitwise_and(T.shift_right(lv575_local[v0, vax1_0_fused_ax1_1_fused_0 * T.int64(32) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) + vax1_0_fused_ax1_1_fused_2 // T.int64(2)], T.Cast("uint32", (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) % T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v0, (vax1_0_fused_ax1_1_fused_0 * T.int64(256) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) * T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)) // T.int64(32)])
+            for ax2_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                    for ax2_fused_1_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+                        for ax2_fused_1_1 in T.vectorized(T.int64(1)):
+                            with T.block("NT_matmul_rf_init"):
+                                vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.spatial(T.int64(32), ax0)
+                                v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                T.reads()
+                                T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
+                                var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = T.float16(0)
+                            for ax1 in range(T.int64(4)):
+                                with T.block("NT_matmul_rf_update"):
+                                    vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 = T.axis.remap("SR", [ax0, ax1])
+                                    v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+                                    T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0], var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0])
+                                    T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
+                                    var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] = var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0] + var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, T.int64(0), T.int64(0), v0]
+            for ax1_fused_1 in range(T.int64(1)):
+                for ax1_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                    for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+                        with T.block("NT_matmul"):
+                            vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 = T.axis.reduce(T.int64(32), ax0)
+                            v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0 + ax1_fused_1)
+                            T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0])
+                            T.writes(var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+                            with T.init():
+                                var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = T.float16(0)
+                            var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] = var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0] + var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0, T.int64(0), T.int64(0), v0]
+            for ax0_fused_0 in T.thread_binding(T.int64(16), thread="threadIdx.y"):
+                for ax0_fused_1 in range(T.int64(1)):
+                    with T.block("T_add"):
+                        v0 = T.axis.spatial(T.int64(4096), u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0 + ax0_fused_1)
+                        T.reads(lv570[T.int64(0), T.int64(0), v0], var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+                        T.writes(p_output0_intermediate[T.int64(0), T.int64(0), v0])
+                        p_output0_intermediate[T.int64(0), T.int64(0), v0] = lv570[T.int64(0), T.int64(0), v0] + var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]
+
+    # fmt: on
+
+    mod = tvm.IRModule({"main": before})
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    mod.show(black_format=False)
+    tvm.ir.assert_structural_equal(mod["main"], expected)
+
 if __name__ == "__main__":
     tvm.testing.main()
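Editor's note: the test above drives the new rule through dl.ApplyDefaultSchedule(dl.gpu.GEMV()) under a CUDA target. For readers who want to try the rule outside the test suite, the following is a minimal usage sketch, not part of this patch: simple_gemv is an illustrative float16 GEMV written for this note, and whether the rule fires on a given workload still depends on the pattern checks in gemv.py.

    # Usage sketch (illustrative, not from the patch): apply the dlight GEMV rule
    # to a hand-written GEMV PrimFunc and print the scheduled module.
    import tvm
    from tvm import dlight as dl
    from tvm.script import tir as T
    from tvm.target import Target


    @T.prim_func
    def simple_gemv(
        A: T.Buffer((4096, 4096), "float16"),
        X: T.Buffer((4096,), "float16"),
        Y: T.Buffer((4096,), "float16"),
    ):
        # Y[i] = sum_k A[i, k] * X[k], written as a TIR reduction block.
        for i, k in T.grid(4096, 4096):
            with T.block("gemv"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    Y[vi] = T.float16(0)
                Y[vi] = Y[vi] + A[vi, vk] * X[vk]


    mod = tvm.IRModule({"main": simple_gemv})
    with Target("nvidia/geforce-rtx-3090-ti"):
        # ApplyDefaultSchedule tries each supplied rule (here only GEMV) on every PrimFunc.
        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
    mod.show()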