diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index b063883800..25f61472d4 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -125,8 +125,9 @@ def normalize(
             if c_loops:
                 return None
             loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor])
-            # we expect the inner most dim to be grouped atm
-            assert not (is_reduction ^ is_inner_reduction)
+            # we only support the inner most dim being grouped atm
+            if is_reduction ^ is_inner_reduction:
+                return None
             c_loops.append(c_loop)
         if is_reduction:
             r_loops.append(loop)
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index fd6850ac60..2fe7c06e33 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -535,9 +535,47 @@ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"), lv576: T
     mod = tvm.IRModule({"main": before})
     with Target("nvidia/geforce-rtx-3090-ti"):
         mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
-    mod.show(black_format=False)
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_autogptq_decode_gemv():
+    # fmt: off
+    @T.prim_func(private=True)
+    def func(lv9: T.Buffer((T.int64(512), T.int64(4096)), "uint32"), lv10: T.Buffer((T.int64(32), T.int64(512)), "uint32"), lv11: T.Buffer((T.int64(32), T.int64(4096)), "float16"), lv12: T.Buffer((T.int64(4096),), "uint32"), lv8: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), lv1613: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"), p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        decode_intermediate = T.alloc_buffer((T.int64(4096), T.int64(4096)), "float16")
+        var_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16")
+        for i, j in T.grid(T.int64(4096), T.int64(4096)):
+            with T.block("decode"):
+                v_i, v_j = T.axis.remap("SS", [i, j])
+                D = T.Buffer((T.int64(4096),), "uint32")
+                T.reads(lv9[v_i // T.int64(8), v_j], lv10[D[v_i], v_j // T.int64(8)], lv12[v_i], lv11[D[v_i], v_j])
+                T.writes(decode_intermediate[v_i, v_j])
+                decode_intermediate[v_i, v_j] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv9[v_i // T.int64(8), v_j], T.Cast("uint32", v_i % T.int64(8) * T.int64(4))), T.uint32(15))) - (T.Cast("float16", T.bitwise_and(T.shift_right(lv10[lv12[v_i], v_j // T.int64(8)], T.Cast("uint32", v_j % T.int64(8) * T.int64(4))), T.uint32(15))) + T.float16(1))) * lv11[lv12[v_i], v_j]
+        for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096), T.int64(4096)):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+                T.reads(lv8[v_i0, v_i1, v_k], decode_intermediate[v_k, v_i2])
+                T.writes(var_matmul_intermediate[v_i0, v_i1, v_i2])
+                with T.init():
+                    var_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+                var_matmul_intermediate[v_i0, v_i1, v_i2] = var_matmul_intermediate[v_i0, v_i1, v_i2] + lv8[v_i0, v_i1, v_k] * decode_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(lv1613[v_ax0, v_ax1, v_ax2], var_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+                T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+                p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv1613[v_ax0, v_ax1, v_ax2] + var_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+    # fmt: on
+
+    # The GeMV rule only supports the innermost dim being grouped for now,
+    # so it is expected to skip transforming this function.
+    mod = tvm.IRModule({"main": func})
+    with Target("nvidia/geforce-rtx-3090-ti"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+    tvm.ir.assert_structural_equal(mod["main"], func)
+
+
 if __name__ == "__main__":
     tvm.testing.main()