
Commit

fix loop carried dependency case in create_primfunc
wrongtest-intellif committed Oct 19, 2024
1 parent ff0b07b commit 29c06ed
Showing 2 changed files with 39 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/te/operation/create_primfunc.cc
@@ -353,7 +353,9 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
 
   Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
     int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits());
-    return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
+    auto new_var = Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
+    analyzer->Bind(iter_var->var, new_var);
+    return new_var;
   });
 
   // Step 2. Generate block bodies.
@@ -398,8 +400,8 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
   // Step 3. Generate loop nesting.
   for (size_t i = axes.size(); i > 0; --i) {
     const IterVar& axis = axes[i - 1];
-    PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
-    PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
+    PrimExpr dom_min = analyzer->Simplify(info->transformer(axis->dom->min));
+    PrimExpr dom_extent = analyzer->Simplify(info->transformer(axis->dom->extent));
     const Var& loop_var = Downcast<Var>(bindings[i - 1]);
     body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
   }
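
For context, a minimal reproduction sketch (not code from this commit; it mirrors the new regression test below): topi.nn.adaptive_pool produces reduction extents that depend on the spatial output indices, so the loop bounds emitted in Step 3 need to be rewritten through info->transformer in terms of the remapped loop variables rather than the original te.IterVars.

from tvm import te, topi

# Adaptive average pooling: each output pixel reduces over a window whose
# size depends on the output indices, so the inner reduce-loop extents
# reference the outer spatial loop variables (a loop-carried bound).
x = te.placeholder((1, 1024, 16, 40), "float32", name="x")
y = topi.nn.adaptive_pool(x, (12, 30), pool_type="avg")
f = te.create_prim_func([x, y])  # with this fix, the For bounds use the rebound loop vars
print(f.script())
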
34 changes: 34 additions & 0 deletions tests/python/te/test_te_create_primfunc.py
@@ -887,5 +887,39 @@ def te_workload():
    _check_workload(te_workload, tir_workload)


def test_loop_aware():
    @T.prim_func
    def tir_workload(
        x: T.Buffer((1, 1024, 16, 40), "float32"),
        adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32"),
    ):
        T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
        # fmt: off
        adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
        for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
            for rv0, rv1 in T.grid(T.Select((ax2 * 4 + 4) % 12 == 0, (ax2 * 16 + 16) // 12, (ax2 * 16 + 16) // 12 + 1) - ax2 * 16 // 12, T.Select((ax3 * 10 + 10) % 30 == 0, (ax3 * 40 + 40) // 30, (ax3 * 40 + 40) // 30 + 1) - ax3 * 40 // 30):
                with T.block("adaptive_pool_sum"):
                    v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
                    with T.init():
                        adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0.0)
                    adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + x[v_ax0, v_ax1, v_ax2 * 16 // 12 + v_rv0, v_ax3 * 40 // 30 + v_rv1]
        for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
            with T.block("adaptive_pool_avg"):
                v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
                T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
                T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
                T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
                adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30))
        # fmt: on

    def te_workload():
        x = te.placeholder([1, 1024, 16, 40], "float32", "x")
        y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")
        f = te.create_prim_func([x, y])
        return [x, y]

    _check_workload(te_workload, tir_workload)


if __name__ == "__main__":
    tvm.testing.main()
