diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py index 9318b91492..79e8ce37a8 100644 --- a/python/tvm/dlight/gpu/matmul.py +++ b/python/tvm/dlight/gpu/matmul.py @@ -410,10 +410,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring i0, i1, i2, i3 = sch.split(i, factors=i_factors) j0, j1, j2, j3 = sch.split(j, factors=j_factors) k0, k1 = sch.split(k, k_factors) - sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) - sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) - sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) - sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + if target.arch.startswith("sm_") and int(target.arch[-2:]) > 75: + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3) @@ -631,10 +632,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring i0, i1, i2, i3 = sch.split(i, factors=i_factors) j0, j1, j2, j3 = sch.split(j, factors=j_factors) k0, k1 = sch.split(k, k_factors) - sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) - sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) - sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) - sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) + if target.arch.startswith("sm_") and int(target.arch[-2:]) > 75: + sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6]) + sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1]) + sch.annotate(k1, "software_pipeline_order", [0, 1, 2]) + sch.annotate(k1, "software_pipeline_stage", [0, 0, 1]) sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)