Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TE][CreatePrimFunc] Fix loop carried dependency case with nested block levels #17474

Conversation

wrongtest-intellif
Copy link
Contributor

If the loop domain depends on other loops, currently there is missing transformations in CreatePrimFunc, which lead to undefined variables in lowering.

https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784

@wrongtest-intellif
Copy link
Contributor Author

@tvm-bot re-run

@wrongtest-intellif wrongtest-intellif force-pushed the fix_create_primfunc_with_loop_carried_deps branch from 29c06ed to ed8bbc7 Compare October 22, 2024 08:58
@SimonSongg
Copy link

SimonSongg commented Oct 22, 2024

If the loop domain depends on other loops, currently there is missing transformations in CreatePrimFunc, which lead to undefined variables in lowering.

https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784

Hi, I posted this issue on TVM community, and thanks for providing fast reply and fix! I added a follow-up in the forum but the content is still waiting for approval to be public. So I forward it here:

When I tried to reproduce the error with your test code, I met following error, which is somehow similar but not exactly the same to what I met when compiling the model.

`Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/workspace/tvm/python/tvm/driver/build_module.py", line 297, in build
    rt_mod_host = _driver_ffi.tir_to_runtime(annotated_mods, target_host)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "/workspace/tvm/src/driver/driver_api.cc", line 532, in operator()
    return TIRToRuntime(inputs_arg, host_target);
  File "/workspace/tvm/src/driver/driver_api.cc", line 493, in tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
    auto pair = SplitMixedModule(ir_module, target, target_host);
  File "/workspace/tvm/src/driver/driver_api.cc", line 419, in tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
    mod_mixed = ApplyPasses(mod_mixed, MixedModulePassManager(mod_mixed, target));
  File "/workspace/tvm/src/driver/driver_api.cc", line 290, in tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
    mod = seq(std::move(mod));
  File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 435, in operator()
    func = MakePackedAPI(std::move(func));
  File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 398, in tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
    ICHECK_EQ(undefined.size(), 0) << "In PrimFunc " << name_hint << " variables " << undefined
tvm.error.InternalError: Traceback (most recent call last):
  5: operator()
        at /workspace/tvm/src/driver/driver_api.cc:532
  4: tvm::TIRToRuntime(tvm::runtime::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)
        at /workspace/tvm/src/driver/driver_api.cc:493
  3: tvm::SplitMixedModule(tvm::IRModule, tvm::Target const&, tvm::Target const&)
        at /workspace/tvm/src/driver/driver_api.cc:419
  2: tvm::ApplyPasses(tvm::IRModule, tvm::transform::Sequential)
        at /workspace/tvm/src/driver/driver_api.cc:290
  1: operator()
        at /workspace/tvm/src/tir/transforms/make_packed_api.cc:435
  0: tvm::tir::MakePackedAPI(tvm::tir::PrimFunc)
        at /workspace/tvm/src/tir/transforms/make_packed_api.cc:398
  File "/workspace/tvm/src/tir/transforms/make_packed_api.cc", line 398
InternalError: Check failed: undefined.size() == 0 (2 vs. 0) : In PrimFunc default_function variables [ax2, ax3] are used, but are not passed in as API arguments`

I also applied your fix to tvm, and when I re-compile the model, new error appeared as follows:

`build
    relax.build(
  File "/workspace/tvm/python/tvm/relax/vm_build.py", line 335, in build
    mod = pipeline(mod)
  File "/workspace/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
  File "tvm/_ffi/_cython/core.cpp", line 7494, in __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback
    TVMAPISetLastPythonError(((void *)__pyx_v_err));
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/mnt/volumes/jointmodel/songtianchen/mlc-llm-dev-eagle-lpai/python/mlc_llm/compiler_pass/pipeline.py", line 181, in _pipeline
    mod = seq(mod)
  File "/workspace/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 263, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./packed_func.pxi", line 252, in tvm._ffi._cy3.core.FuncCall3
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "tvm/_ffi/_cython/core.cpp", line 7494, in __pyx_f_3tvm_4_ffi_4_cy3_4core_tvm_callback
    TVMAPISetLastPythonError(((void *)__pyx_v_err));
  File "tvm/_ffi/_cython/./packed_func.pxi", line 56, in tvm._ffi._cy3.core.tvm_callback
  File "/workspace/tvm/python/tvm/ir/transform.py", line 307, in _pass_func
    return inst.transform_module(mod, ctx)
  File "/workspace/tvm/python/tvm/dlight/base/transform.py", line 71, in transform_module
    sch = _apply_rules(func, target, self.rules, tunable=False)
  File "/workspace/tvm/python/tvm/dlight/base/transform.py", line 87, in _apply_rules
    space = rule.apply(func, target, tunable)
  File "/workspace/tvm/python/tvm/dlight/gpu/general_reduction.py", line 114, in apply
    sch.compute_at(block, bx, preserve_unit_loops=True)
  File "/workspace/tvm/python/tvm/tir/schedule/_type_checker.py", line 340, in wrap
    return func(*args, **kwargs)
  File "/workspace/tvm/python/tvm/tir/schedule/schedule.py", line 2111, in compute_at
    _ffi_api.ScheduleComputeAt(  # type: ignore # pylint: disable=no-member
  File "tvm/_ffi/_cython/./packed_func.pxi", line 332, in tvm._ffi._cy3.core.PackedFuncBase.__call__
  File "tvm/_ffi/_cython/./packed_func.pxi", line 277, in tvm._ffi._cy3.core.FuncCall
  File "tvm/_ffi/_cython/./base.pxi", line 182, in tvm._ffi._cy3.core.CHECK_CALL
  File "/workspace/tvm/python/tvm/_ffi/base.py", line 481, in raise_last_ffi_error
    raise py_err
tvm.tir.schedule.schedule.ScheduleError: Traceback (most recent call last):
  1: tvm::tir::TracedScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
        at /workspace/tvm/src/tir/schedule/traced_schedule.cc:489
  0: tvm::tir::ConcreteScheduleNode::ComputeAt(tvm::tir::BlockRV const&, tvm::tir::LoopRV const&, bool, int)
        at /workspace/tvm/src/tir/schedule/concrete_schedule.cc:790
ScheduleError: An error occurred in the schedule primitive 'compute-at'.
The IR with diagnostic is:
# from tvm.script import ir as I
# from tvm.script import tir as T

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(var_reshape240: T.handle, var_adaptive_pool_avg: T.handle):
        T.func_attr({"tir.noalias": T.bool(True)})
        reshape240 = T.match_buffer(var_reshape240, (T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16")
        adaptive_pool_avg = T.match_buffer(var_adaptive_pool_avg, (T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        # tir.Block#0
        with T.block("root"):
        ^^^^^^^^^^^^^^^^^^^^^
            T.reads()
            ^^^^^^^^^
            T.writes()
            ^^^^^^^^^^
            adaptive_pool_sum_shared = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16", scope="shared")
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0 in range(T.int64(2)):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax1 in range(T.int64(1024)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for ax2 in range(T.int64(12)):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for ax3 in range(T.int64(30)):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            for ax4 in range(T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12)):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                ax2_1 = T.int64()
                                ^^^^^^^^^^^^^^^^^
                                for ax5 in range(T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                    ax3_1 = T.int64()
                                    ^^^^^^^^^^^^^^^^^
                                    with T.block("adaptive_pool_sum"):
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v0 = T.axis.spatial(T.int64(2), ax0)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v1 = T.axis.spatial(T.int64(1024), ax1)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v2 = T.axis.spatial(T.int64(12), ax2)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v3 = T.axis.spatial(T.int64(30), ax3)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v4 = T.axis.reduce(T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), ax4)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        v5 = T.axis.reduce(T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30), ax5)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        T.writes(adaptive_pool_sum_shared[v0, v1, v2, v3])
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        with T.init():
                                        ^^^^^^^^^^^^^^
                                            adaptive_pool_sum_shared[v0, v1, v2, v3] = T.float16(0)
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                        adaptive_pool_sum_shared[v0, v1, v2, v3] = adaptive_pool_sum_shared[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
            for ax0_ax1_ax2_ax3_fused in T.thread_binding(T.int64(737280), thread="blockIdx.x"):
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                for ax4 in range(T.int64(1)):
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                    for ax5_1 in T.thread_binding(T.int64(256), thread="threadIdx.x"):
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                        for ax5_0 in T.serial(T.int64(1), annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                            with T.block("adaptive_pool_avg"):
                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v0 = T.axis.spatial(T.int64(2), ax0_ax1_ax2_ax3_fused // T.int64(368640))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v1 = T.axis.spatial(T.int64(1024), ax0_ax1_ax2_ax3_fused % T.int64(368640) // T.int64(360))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v2 = T.axis.spatial(T.int64(12), ax0_ax1_ax2_ax3_fused % T.int64(360) // T.int64(30))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v3 = T.axis.spatial(T.int64(30), ax0_ax1_ax2_ax3_fused % T.int64(30))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v4 = T.axis.spatial(T.int64(1), ax4)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                v5 = T.axis.spatial(T.int64(1), ax5_0 * T.int64(256) + ax5_1)
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.where(ax5_0 * T.int64(256) + ax5_1 < T.int64(1))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.reads(adaptive_pool_sum_shared[v0, v1, v2, v3])
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.writes(adaptive_pool_avg[v0, v1, v2, v3])
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
                                adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum_shared[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v3) * T.int64(40) // T.int64(30)))
                                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Error message: The scope tir.Block#0 is not a stage pipeline.
Definition of a scope that is a stage pipeline:
- The region cover property holds for every of its child blocks
- No write-after-read dependency or opaque dependency,
- only read-after-write and write-after-write are allowed
- All the statements in the scope are schedulable statements, i.e. Block and For`

I also printed out the generated TIR by _DebugDump in compile pass, and I compared it with the code in the test code you provided, I found they are almost the same. The difference is in my TIR, there are some cast to int64 like T.Cast("int64", v_ax3):

`@T.prim_func(private=True)
    def adaptive_avg_pool2d(reshape72: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)):
                with T.block("adaptive_pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                    T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape72[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12), (T.Cast("int64", v_ax2) * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - T.Cast("int64", v_ax2) * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30), (T.Cast("int64", v_ax3) * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - T.Cast("int64", v_ax3) * T.int64(40) // T.int64(30)))`

And in my new error message, the TIR printed out in terminal shows there are still undefined variable usage for ax2_1 and ax3_1. And the TIR is different from what I got from _DebugDump above.

Could you help look into further on this issue? Many thanks!

@wrongtest-intellif
Copy link
Contributor Author

When I tried to reproduce the error with your test code, I met following error, which is somehow similar but not exactly the same to what I met when compiling the model.

Hi, thanks for the follow up information. Could you help to provide the compile script from relax? @SimonSongg

@SimonSongg
Copy link

SimonSongg commented Oct 22, 2024

When I tried to reproduce the error with your test code, I met following error, which is somehow similar but not exactly the same to what I met when compiling the model.

Hi, thanks for the follow up information. Could you help to provide the compile script from relax? @SimonSongg

Hi,

def _build_default():
    def build(mod: IRModule, args: "CompileArgs", pipeline=None):
        output = args.output
        if output.suffix in [".tar", ".lib"]:
            system_lib = True
        elif output.suffix in [".so", ".dylib", ".dll"]:
            system_lib = False
        else:
            logger.warning("Unknown output suffix: %s. Assuming shared library.", output.suffix)
            system_lib = False
        mod = _add_system_lib_prefix(mod, args.system_lib_prefix, is_system_lib=system_lib)
        relax.build(
            mod,
            target=args.target,
            pipeline=pipeline,
            system_lib=system_lib,
        ).export_library(
            str(output),
        )

    return build

Here is the code mlc-llm used to compile model.

The pipeline is:

@tvm.transform.module_pass(opt_level=0)
    def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.IRModule:
        seq = tvm.transform.Sequential(
            [
                # Phase 0. Add additional information for compilation and remove unused Relax func
                DispatchKVCacheCreation(target, flashinfer, metadata),
                # AttachSoftmaxWithTemperature(target),
                AttachVariableBounds(variable_bounds),
                AttachCUDAGraphSymbolicCaptureHints(cuda_graph_symbolic_capture_hints),
                AttachLogitProcessFunc(target),
                AttachAdditionalPrimFuncs(additional_tirs),
                AttachAllocEmbeddingTensorFunc(metadata),
                # AttachGPUSamplingFunc(target, variable_bounds),
                AttachSpecDecodeAuxFuncs(tensor_parallel_shards),
                AttachMemoryPlanAttr(),
                tvm.tir.transform.BindTarget(tvm.target.Target.current(allow_none=False)),
                _DebugDump("debug-phase0.py", debug_dump, show_meta=False),
                # Phase 1. Passes on high-level operator graph
                _LogProgress("Running TVM Relax graph-level optimizations"),
                FuseFTDequantizeEpilogue(),
                FuseDequantizeTranspose(),
                CublasDispatch() if cublas_gemm else tvm.transform.Sequential([]),
                FuseAddRMSNorm(target=target),
                FuseTransposeMatmul(),
                _DebugDump("debug-phase1.py", debug_dump, show_meta=False),
                # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
                _LogProgress("Lowering to TVM TIR kernels"),
                tvm.relax.backend.DispatchSortScan(),
                tvm.relax.transform.LegalizeOps(),
                tvm.relax.transform.AnnotateTIROpPattern(),
                tvm.relax.transform.FoldConstant(),
                tvm.relax.transform.FuseOps(),
                tvm.relax.transform.FuseTIR(),
                # _DebugDump("debug-phase2.py", debug_dump, show_meta=False),
                # Phase 3. Passes on TIR
                _LogProgress("Running TVM TIR-level optimizations"),
                FuseDequantizeMatmulEwise(),
                FuseDequantizeTake(),
                tvm.relax.transform.DeadCodeElimination(),
                CleanUpTIRAttrs(["op_pattern"]),
                # _DebugDump("debug-phase3.py", debug_dump, show_meta=False),
                # Phase 4. Low-level Optimizations
                _LogProgress("Running TVM Dlight low-level optimizations"),
                LowBatchGemvSpecialize(),
                dl.ApplyDefaultSchedule(
                    dl.gpu.Matmul(),
                    dl.gpu.GEMV(),
                    dl.gpu.Reduction(),
                    dl.gpu.GeneralReduction(),
                    dl.gpu.Fallback(),
                ),
                _DebugDump("debug-phase4.py", debug_dump, show_meta=False),
                _LogProgress("Lowering to VM bytecode"),
                LiftTIRGlobalBufferAlloc(),
                (
                    tvm.tir.transform.ForceNarrowIndexToInt32()
                    if target.kind.name != "cuda"
                    else tvm.transform.Sequential([])
                ),
                ScatterTupleGetItem(),
                tvm.relax.transform.RewriteDataflowReshape(),
                tvm.relax.transform.ToNonDataflow(),
                tvm.relax.transform.RemovePurityChecking(),
                tvm.relax.transform.CallTIRRewrite(),
                (
                    tvm.relax.transform.IPCAllReduceRewrite(allreduce_strategy)
                    if allreduce_strategy != IPCAllReduceStrategyType.NONE
                    else tvm.transform.Sequential([])
                ),
                tvm.relax.transform.StaticPlanBlockMemory(),
                AttachMetadataWithMemoryUsage(metadata),
                tvm.relax.transform.RewriteCUDAGraph(),
                tvm.relax.transform.LowerGPUIPCAllocStorage(),
                tvm.relax.transform.LowerAllocTensor(),
                tvm.relax.transform.KillAfterLastUse(),
                tvm.relax.transform.VMBuiltinLower(),
                tvm.relax.transform.VMShapeLower(),
                tvm.relax.transform.AttachGlobalSymbol(),
                _DebugDump("debug-final.py", debug_dump, show_meta=False),
                _LogProgress("Compiling external modules"),
                tvm.relax.transform.AttachExternModules(ext_mods),
                _LogProgress("Compilation complete! Exporting to disk"),
            ]
        )
        mod = seq(mod)
        return mod

Looks like the error occurs at dl.gpu.GeneralReduction()

Thanks!

@tqchen tqchen requested a review from Hzfengsy October 22, 2024 17:16
@SimonSongg
Copy link

SimonSongg commented Oct 24, 2024

When I tried to reproduce the error with your test code, I met following error, which is somehow similar but not exactly the same to what I met when compiling the model.

Hi, thanks for the follow up information. Could you help to provide the compile script from relax? @SimonSongg

Hi, sorry for bothering @wrongtest-intellif , but I have a new finding.

In tvm/dlight/gpu/general_reduction.py, I printed out the ir module by:

sch = tir.Schedule(func)
        print("===========================")
        sch.show()
        block_infos = normalize_prim_func(sch)
        print("===========================")
        sch.show()

And I found normalize_prim_func() change the prim_func, which made the problem https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784 solved by you occurred again!!

Before normalize_prim_func:

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)):
                with T.block("adaptive_pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                    T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v_ax2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v_ax3 * T.int64(40) // T.int64(30)))

After normalize_prim_func:

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
            ax2_1 = T.int64()  ######## HERE ########
            ax3_1 = T.int64()  ######## HERE ########
            with T.block("adaptive_pool_sum"):
                v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5])
                T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5])
                T.writes(adaptive_pool_sum[v0, v1, v2, v3])
                with T.init():
                    adaptive_pool_sum[v0, v1, v2, v3] = T.float16(0)
                adaptive_pool_sum[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v0, v1, v2, v3])
                T.writes(adaptive_pool_avg[v0, v1, v2, v3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v2 * T.int64(16) + T.int64(16)) // T.int64(12), (v2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v3 * T.int64(40) + T.int64(40)) // T.int64(30), (v3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v3 * T.int64(40) // T.int64(30)))

Could you give some clues about how to fix this bug? Thanks you very much!

@wrongtest-intellif
Copy link
Contributor Author

When I tried to reproduce the error with your test code, I met following error, which is somehow similar but not exactly the same to what I met when compiling the model.

Hi, thanks for the follow up information. Could you help to provide the compile script from relax? @SimonSongg

Hi, sorry for bothering @wrongtest-intellif , but I have a new finding.

In tvm/dlight/gpu/general_reduction.py, I printed out the ir module by:

sch = tir.Schedule(func)
        print("===========================")
        sch.show()
        block_infos = normalize_prim_func(sch)
        print("===========================")
        sch.show()

And I found normalize_prim_func() change the prim_func, which made the problem https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784 solved by you occurred again!!

Before normalize_prim_func:

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)):
                with T.block("adaptive_pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                    T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v_ax2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v_ax3 * T.int64(40) // T.int64(30)))

After normalize_prim_func:

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
            ax2_1 = T.int64()  ######## HERE ########
            ax3_1 = T.int64()  ######## HERE ########
            with T.block("adaptive_pool_sum"):
                v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5])
                T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5])
                T.writes(adaptive_pool_sum[v0, v1, v2, v3])
                with T.init():
                    adaptive_pool_sum[v0, v1, v2, v3] = T.float16(0)
                adaptive_pool_sum[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v0, v1, v2, v3])
                T.writes(adaptive_pool_avg[v0, v1, v2, v3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v2 * T.int64(16) + T.int64(16)) // T.int64(12), (v2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v3 * T.int64(40) + T.int64(40)) // T.int64(30), (v3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v3 * T.int64(40) // T.int64(30)))

Could you give some clues about how to fix this bug? Thanks you very much!

This is because except the primfunc creation, the schedule system and auto-schedule rules also has potential issues to not correctly take loop carried dependency into consideration. Since they are generally developed to optimize static shape workloads.

We are trying to find proper way out for such workloads.

@SimonSongg
Copy link

When I tried to reproduce the error with your test code, I met following error, which is somehow similar but not exactly the same to what I met when compiling the model.

Hi, thanks for the follow up information. Could you help to provide the compile script from relax? @SimonSongg

Hi, sorry for bothering @wrongtest-intellif , but I have a new finding.
In tvm/dlight/gpu/general_reduction.py, I printed out the ir module by:

sch = tir.Schedule(func)
        print("===========================")
        sch.show()
        block_infos = normalize_prim_func(sch)
        print("===========================")
        sch.show()

And I found normalize_prim_func() change the prim_func, which made the problem https://discuss.tvm.apache.org/t/compilation-error-for-adaptive-avg-pool2d-relax-op-in-mlc-llm/17784 solved by you occurred again!!
Before normalize_prim_func:

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            for rv0, rv1 in T.grid(T.Select((ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2 * T.int64(16) // T.int64(12), T.Select((ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3 * T.int64(40) // T.int64(30)):
                with T.block("adaptive_pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    T.reads(reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1])
                    T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                    with T.init():
                        adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float16(0)
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + reshape240[v_ax0, v_ax1, v_ax2 * T.int64(16) // T.int64(12) + v_rv0, v_ax3 * T.int64(40) // T.int64(30) + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float16", T.Select((v_ax2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12), (v_ax2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v_ax2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v_ax3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30), (v_ax3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v_ax3 * T.int64(40) // T.int64(30)))

After normalize_prim_func:

@I.ir_module
class Module:
    @T.prim_func(private=True)
    def main(reshape240: T.Buffer((T.int64(2), T.int64(1024), T.int64(16), T.int64(40)), "float16"), adaptive_pool_avg: T.Buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")):
        T.func_attr({"tir.noalias": T.bool(True)})
        # with T.block("root"):
        adaptive_pool_sum = T.alloc_buffer((T.int64(2), T.int64(1024), T.int64(12), T.int64(30)), "float16")
        for ax0, ax1, ax2, ax3, ax4, ax5 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30), T.Select((ax2_1 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12), (ax2_1 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - ax2_1 * T.int64(16) // T.int64(12), T.Select((ax3_1 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30), (ax3_1 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - ax3_1 * T.int64(40) // T.int64(30)):
            ax2_1 = T.int64()  ######## HERE ########
            ax3_1 = T.int64()  ######## HERE ########
            with T.block("adaptive_pool_sum"):
                v0, v1, v2, v3, v4, v5 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, ax4, ax5])
                T.reads(reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5])
                T.writes(adaptive_pool_sum[v0, v1, v2, v3])
                with T.init():
                    adaptive_pool_sum[v0, v1, v2, v3] = T.float16(0)
                adaptive_pool_sum[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] + reshape240[v0, v1, v2 * T.int64(16) // T.int64(12) + v4, v3 * T.int64(40) // T.int64(30) + v5]
        for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(1024), T.int64(12), T.int64(30)):
            with T.block("adaptive_pool_avg"):
                v0, v1, v2, v3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v0, v1, v2, v3])
                T.writes(adaptive_pool_avg[v0, v1, v2, v3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v0, v1, v2, v3] = adaptive_pool_sum[v0, v1, v2, v3] / (T.Cast("float16", T.Select((v2 * T.int64(4) + T.int64(4)) % T.int64(12) == T.int64(0), (v2 * T.int64(16) + T.int64(16)) // T.int64(12), (v2 * T.int64(16) + T.int64(16)) // T.int64(12) + T.int64(1)) - v2 * T.int64(16) // T.int64(12)) * T.Cast("float16", T.Select((v3 * T.int64(10) + T.int64(10)) % T.int64(30) == T.int64(0), (v3 * T.int64(40) + T.int64(40)) // T.int64(30), (v3 * T.int64(40) + T.int64(40)) // T.int64(30) + T.int64(1)) - v3 * T.int64(40) // T.int64(30)))

Could you give some clues about how to fix this bug? Thanks you very much!

This is because except the primfunc creation, the schedule system and auto-schedule rules also has potential issues to not correctly take loop carried dependency into consideration. Since they are generally developed to optimize static shape workloads.

We are trying to find proper way out for such workloads.

Thank you very much for your reply!

@tqchen
Copy link
Member

tqchen commented Nov 4, 2024

cc @Hzfengsy can you help to review this PR

@wrongtest-intellif wrongtest-intellif marked this pull request as draft November 5, 2024 02:59
@wrongtest-intellif wrongtest-intellif force-pushed the fix_create_primfunc_with_loop_carried_deps branch 2 times, most recently from 4fe1133 to 97a2799 Compare November 9, 2024 03:59
@wrongtest-intellif wrongtest-intellif force-pushed the fix_create_primfunc_with_loop_carried_deps branch from 97a2799 to 7f71550 Compare November 13, 2024 10:22
@wrongtest-intellif wrongtest-intellif changed the title [TE][CreatePrimFunc] Fix loop carried dependency case [TE][CreatePrimFunc] Fix loop carried dependency case with nested block levels Nov 14, 2024
@wrongtest-intellif
Copy link
Contributor Author

Here is some updates for new change. Because TE could define axis with it's domain depend on previous axes. It is a problem to convert such compute op to one single block since the block iter vars should be insensitive to their relative positions defined.

It seems better to create nested block levels to represent such workloads. And ensure each level's block iter vars are independent. For the adaptive pooling case in the context, changed create_primfunc would generate as below

  1. outer block only define spacial iterations and related loops
  2. inner block treat outer spacial vars as free, and define reduce iterations and trivial spacial iterations.

This decomposition could ensure independency for block vars.

@T.prim_func
def tir_workload(x: T.Buffer((1, 1024, 16, 40), "float32"), adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32")):
        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
        adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
        for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
            with T.block("adaptive_pool_sum_1"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(x[v_ax0, v_ax1, v_ax2 * 16 // 12:v_ax2 * 16 // 12 + ((v_ax2 % 3 * 4 + 16) // 12 + 1), v_ax3 * 40 // 30:v_ax3 * 40 // 30 + ((v_ax3 % 3 * 10 + 40) // 30 + 1)])
                T.writes(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                for rv0, rv1 in T.grid(T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12, T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30):
                    with T.block("adaptive_pool_sum"):
                        v_ax0_1 = T.axis.spatial((v_ax0, v_ax0 + 1), v_ax0)
                        v_ax1_1 = T.axis.spatial((v_ax1, v_ax1 + 1), v_ax1)
                        v_ax2_1 = T.axis.spatial((v_ax2, v_ax2 + 1), v_ax2)
                        v_ax3_1 = T.axis.spatial((v_ax3, v_ax3 + 1), v_ax3)
                        v_rv0, v_rv1 = T.axis.remap("RR", [rv0, rv1])
                        T.reads(x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1])
                        T.writes(adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1])
                        with T.init():
                            adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = T.float32(0.0)
                        adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] = adaptive_pool_sum[v_ax0_1, v_ax1_1, v_ax2_1, v_ax3_1] + x[v_ax0_1, v_ax1_1, v_ax2_1 * 16 // 12 + v_rv0, v_ax3_1 * 40 // 30 + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30))

I also try some tune methods:

  1. dlight's general_reduction, failed. I think that is because the outer level of block is not reduction block now.
  2. metaschedule with llvm or cuda target. If I disable the DisallowDynamicLoop, and force skip the inner block in space generation with f_block_filter. It could produce correct and optimized results. The trace example could be
def apply_trace(sch: tir.Schedule) -> None:
  b0 = sch.get_block(name="adaptive_pool_sum_l1", func_name="main")
  b1 = sch.get_block(name="adaptive_pool_avg", func_name="main")
  b2 = sch.get_block(name="root", func_name="main")
  sch.unannotate(block_or_loop=b1, ann_key="schedule_rule")
  v3 = sch.sample_categorical(candidates=[0, 16, 64, 512, 1024], probs=[0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001, 0.20000000000000001], decision=4)
  sch.annotate(block_or_loop=b2, ann_key="meta_schedule.unroll_explicit", ann_val=v3)
  l4, l5, l6, l7 = sch.get_loops(block=b1)
  l8 = sch.fuse(l4, l5, l6, l7, preserve_unit_iters=True)
  l9, l10, l11 = sch.split(loop=l8, factors=[None, 256, 1024], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l10, l11, l9)
  sch.bind(loop=l10, thread_axis="blockIdx.x")
  sch.bind(loop=l11, thread_axis="threadIdx.x")
  l12, l13, l14, l15 = sch.get_loops(block=b0)
  l16 = sch.fuse(l12, l13, l14, l15, preserve_unit_iters=True)
  l17, l18, l19 = sch.split(loop=l16, factors=[None, 256, 1024], preserve_unit_iters=True, disable_predication=False)
  sch.reorder(l18, l19, l17)
  sch.bind(loop=l18, thread_axis="blockIdx.x")
  sch.bind(loop=l19, thread_axis="threadIdx.x")
  sch.enter_postproc()
  b20 = sch.get_block(name="root", func_name="main")
  sch.unannotate(block_or_loop=b20, ann_key="meta_schedule.unroll_explicit")
  b21, b22 = sch.get_child_blocks(b20)
  l23, l24, l25 = sch.get_loops(block=b21)
  l26, l27, l28 = sch.get_loops(block=b22)
  b29 = sch.get_block(name="adaptive_pool_sum", func_name="main")
  l30, l31 = sch.get_loops(block=b29)
  b32 = sch.decompose_reduction(block=b29, loop=l30)

@wrongtest-intellif wrongtest-intellif marked this pull request as ready for review November 14, 2024 02:32
@Hzfengsy Hzfengsy merged commit 370ec6a into apache:main Nov 14, 2024
21 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants