[TE][CreatePrimFunc] Fix loop carried dependency case with nested block levels #17474
Conversation
@tvm-bot re-run
Hi, I posted this issue on the TVM community forum, and thanks for providing a fast reply and fix! I added a follow-up in the forum, but the content is still awaiting approval to become public, so I am forwarding it here: when I tried to reproduce the error with your test code, I met the following error, which is similar but not exactly the same as what I met when compiling the model.
I also applied your fix to TVM, and when I recompiled the model, a new error appeared, as follows:
I also printed out the generated TIR via _DebugDump in the compile pass and compared it with the test code you provided; they are almost the same. The difference is that my TIR contains some casts to int64, like
And in my new error message, the TIR printed in the terminal shows there is still an undefined variable usage. Could you help look further into this issue? Many thanks!
Hi, thanks for the follow-up information. Could you provide the compile script from Relax? @SimonSongg
Hi,
Here is the code mlc-llm uses to compile the model. The pipeline is:
Looks like the error occurs at
Thanks!
Hi, sorry for bothering you @wrongtest-intellif, but I have a new finding. In
And I found: Before
After
Could you give some clues about how to fix this bug? Thank you very much!
This is because, besides the PrimFunc creation, the schedule system and auto-scheduling rules also have potential issues where they do not correctly take loop-carried dependencies into consideration, since they were generally developed to optimize static-shape workloads. We are trying to find a proper way out for such workloads.
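The mismatch between dependent loop domains and rules written for static shapes can be illustrated outside TVM. Below is a pure-Python analogy (not TVM code; the 16-to-6 axis size is an arbitrary choice for illustration): an adaptive-pool-style axis whose per-output window extent varies with the outer index, where a rewrite assuming one rectangular extent silently over-counts.

```python
# Pure-Python analogy (not TVM code): why transformations that assume a
# static, rectangular iteration domain can mishandle dependent loop bounds.
import math

IN, OUT = 16, 6  # one adaptive-pool axis: 16 inputs pooled to 6 outputs

def window(i):
    # Per-output read window [start, end); note the *extent* varies with i.
    return i * IN // OUT, math.ceil((i + 1) * IN / OUT)

extents = [window(i)[1] - window(i)[0] for i in range(OUT)]
assert len(set(extents)) > 1  # [3, 4, 3, 3, 4, 3]: not rectangular

x = list(range(IN))

# Correct nest: the inner extent is a function of the outer variable.
dependent = [sum(x[k] for k in range(*window(i))) for i in range(OUT)]

# A rewrite that flattens this into one rectangular extent (the maximum
# window size) reads too many elements for the smaller windows.
max_extent = max(extents)
rectangular = [
    sum(x[min(window(i)[0] + r, IN - 1)] for r in range(max_extent))
    for i in range(OUT)
]
assert dependent != rectangular  # the naive rectangularization is wrong
```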
Thank you very much for your reply!
cc @Hzfengsy, can you help review this PR?
Here are some updates for the new change. Because TE can define an axis whose domain depends on previous axes, it is a problem to convert such a compute op into one single block, since block iter vars should be insensitive to the relative positions at which they are defined. It seems better to create nested block levels to represent such workloads, and to ensure each level's block iter vars are independent. For the adaptive pooling case in this context, the generated structure is changed as follows.
This decomposition could ensure the independence of block vars:

@T.prim_func
def tir_workload(x: T.Buffer((1, 1024, 16, 40), "float32"), adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32")):
    T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
    adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
    for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
        with T.block("adaptive_pool_sum_1"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)])
            T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30):
                with T.block("adaptive_pool_sum"):
                    v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
                    v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
                    v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2)
                    v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3)
                    v_rv0, v_rv1 = T.axis.remap("RR", [rv0, rv1])
                    T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1])
                    T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1])
                    with T.init():
                        adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0)
                    adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]
    for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
        with T.block("adaptive_pool_avg"):
            v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
            T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
            T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
            T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
            adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30))

I also tried some tuning methods:
def apply_trace(sch: tir.Schedule) -> None:
    b0 = sch.get_block(name="adaptive_pool_sum_l1", func_name="main")
    b1 = sch.get_block(name="adaptive_pool_avg", func_name="main")
    b2 = sch.get_block(name="root", func_name="main")
    sch.unannotate(block_or_loop=b1, ann_key="schedule_rule")
    v3 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4)
    sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v3)
    l4, l5, l6, l7 = sch.get_loops(block=b1)
    l8 = sch.fuse(l4, l5, l6, l7, preserve_unit_iters=True)
    l9, l10, l11 = sch.split(loop=l8, factors=[None, 256, 1024], preserve_unit_iters=True, disable_predication=False)
    sch.reorder(l10, l11, l9)
    sch.bind(loop=l10, thread_axis="blockIdx.x")
    sch.bind(loop=l11, thread_axis="threadIdx.x")
    l12, l13, l14, l15 = sch.get_loops(block=b0)
    l16 = sch.fuse(l12, l13, l14, l15, preserve_unit_iters=True)
    l17, l18, l19 = sch.split(loop=l16, factors=[None, 256, 1024], preserve_unit_iters=True, disable_predication=False)
    sch.reorder(l18, l19, l17)
    sch.bind(loop=l18, thread_axis="blockIdx.x")
    sch.bind(loop=l19, thread_axis="threadIdx.x")
    sch.enter_postproc()
    b20 = sch.get_block(name="root", func_name="main")
    sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.unroll_explicit")
    b21, b22 = sch.get_child_blocks(b20)
    l23, l24, l25 = sch.get_loops(block=b21)
    l26, l27, l28 = sch.get_loops(block=b22)
    b29 = sch.get_block(name="adaptive_pool_sum", func_name="main")
    l30, l31 = sch.get_loops(block=b29)
    b32 = sch.decompose_reduction(block=b29, loop=l30)
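As a side check on the window arithmetic appearing in the generated TIR (a standalone Python sketch, not part of the PR): the read-region start v_ax2 * 16 // 12 and the T.Select(...) end expression are integer forms of floor(i * 16 / 12) and ceil((i + 1) * 16 / 12), and the Select condition (v_ax2 * 4 + 4) % 12 == 0 is ((i + 1) * 16) % 12 == 0 with 16 reduced modulo 12.

```python
# Standalone sanity check (plain Python, not TVM): the integer bounds in
# the generated TIR equal the floor/ceil adaptive-pooling window bounds.
import math

def tir_bounds(i, in_size, out_size):
    # Mirrors the TIR expressions: start = i * in // out, and
    # end = Select((i+1)*in % out == 0, (i+1)*in // out, (i+1)*in // out + 1)
    start = i * in_size // out_size
    hi = i * in_size + in_size
    end = hi // out_size if hi % out_size == 0 else hi // out_size + 1
    return start, end

def float_bounds(i, in_size, out_size):
    return math.floor(i * in_size / out_size), math.ceil((i + 1) * in_size / out_size)

# The two pooled axes in the example: 16 -> 12 and 40 -> 30.
for in_size, out_size in [(16, 12), (40, 30)]:
    for i in range(out_size):
        assert tir_bounds(i, in_size, out_size) == float_bounds(i, in_size, out_size)
```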
If the loop domain depends on other loops, there are currently missing transformations in CreatePrimFunc, which lead to undefined variables during lowering.
https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784