Skip to content

Commit

Permalink
[HotFix] Skip sw pipeline for dlight gemm for low SM
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Feb 22, 2024
1 parent 59c3556 commit 7999113
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions python/tvm/dlight/gpu/matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
i0, i1, i2, i3 = sch.split(i, factors=i_factors)
j0, j1, j2, j3 = sch.split(j, factors=j_factors)
k0, k1 = sch.split(k, k_factors)
sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6])
sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1])
sch.annotate(k1, "software_pipeline_order", [0, 1, 2])
sch.annotate(k1, "software_pipeline_stage", [0, 0, 1])
if target.arch.startswith("sm_") and int(target.arch[-2:]) > 75:
sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6])
sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1])
sch.annotate(k1, "software_pipeline_order", [0, 1, 2])
sch.annotate(k1, "software_pipeline_stage", [0, 0, 1])

sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)

Expand Down Expand Up @@ -631,10 +632,11 @@ def apply( # pylint: disable=too-many-locals,missing-docstring
i0, i1, i2, i3 = sch.split(i, factors=i_factors)
j0, j1, j2, j3 = sch.split(j, factors=j_factors)
k0, k1 = sch.split(k, k_factors)
sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6])
sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1])
sch.annotate(k1, "software_pipeline_order", [0, 1, 2])
sch.annotate(k1, "software_pipeline_stage", [0, 0, 1])
if target.arch.startswith("sm_") and int(target.arch[-2:]) > 75:
sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6])
sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1])
sch.annotate(k1, "software_pipeline_order", [0, 1, 2])
sch.annotate(k1, "software_pipeline_stage", [0, 0, 1])

sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)

Expand Down

0 comments on commit 7999113

Please sign in to comment.