fix loop carried dependency case in create_primfunc
wrongtest-intellif authored and wrongtest committed Oct 22, 2024
1 parent ff0b07b commit ed8bbc7
Showing 2 changed files with 73 additions and 23 deletions.
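This change handles TE workloads whose iteration domains depend on other iterators (a loop-carried dependency): the domain expressions of each te::IterVar are now substituted with the freshly created loop variables before simplification, both when building block vars in GenerateBlockFromTensors and when emitting the loop nest in GenerateStmtFromCompute. The pattern is exercised by adaptive pooling, where the reduction window size varies with the output coordinate. A minimal Python reproduction, adapted from the test added below (only topi.nn.adaptive_pool, te.create_prim_func, and PrimFunc.script() are assumed):

import tvm
from tvm import te, topi

# Adaptive average pooling: each output element reduces over a window whose size
# depends on the output coordinate, so the reduce axis extent references the
# spatial iterators of the compute op.
x = te.placeholder((1, 1024, 16, 40), "float32", name="x")
y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")

# With this fix, the generated PrimFunc binds those data-dependent bounds to the
# newly created loop variables rather than the original te::IterVar variables.
f = te.create_prim_func([x, y])
print(f.script())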
62 changes: 39 additions & 23 deletions src/te/operation/create_primfunc.cc
@@ -183,19 +183,32 @@ class LayoutFreePlaceholdersNormalizer : public StmtMutator {
BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
const Array<te::Tensor>& tensors, Array<PrimExpr> bindings,
PrimExpr expr_body, CreateFuncInfo* info,
const std::unordered_map<const VarNode*, Var>& loop_var_map,
arith::Analyzer* analyzer) {
std::unordered_map<const VarNode*, Var> block_var_map;

// helper to transform the expr and remap iters to the block domain
auto f_transform_block_domain = [&](const PrimExpr& e) {
return Substitute(info->transformer(e), block_var_map);
};

// helper to transform the expr and remap iters to the loop domain
auto f_transform_loop_domain = [&](const PrimExpr& e) {
return Substitute(info->transformer(e), loop_var_map);
};

// Step 1. Push_back data_par axis and reduce_axis into block_vars.
Array<IterVar> iter_vars;
std::unordered_map<const VarNode*, Var> var_map;
iter_vars.reserve(compute_op->axis.size() + compute_op->reduce_axis.size());
auto f_push_block_vars = [&iter_vars, &var_map, &analyzer](const Array<IterVar>& iters) {
auto f_push_block_vars = [&iter_vars, &block_var_map, &analyzer,
f_transform_loop_domain](const Array<IterVar>& iters) {
for (IterVar iter_var : iters) {
// Create new var
Var new_var("v_" + iter_var->var->name_hint, iter_var->var->dtype);
var_map[iter_var->var.get()] = new_var;
block_var_map[iter_var->var.get()] = new_var;

PrimExpr dom_min = analyzer->Simplify(iter_var->dom->min);
PrimExpr dom_extent = analyzer->Simplify(iter_var->dom->extent);
PrimExpr dom_min = analyzer->Simplify(f_transform_loop_domain(iter_var->dom->min));
PrimExpr dom_extent = analyzer->Simplify(f_transform_loop_domain(iter_var->dom->extent));
iter_vars.push_back(IterVar(Range::FromMinExtent(dom_min, dom_extent), new_var,
iter_var->iter_type, iter_var->thread_tag, iter_var->span));
}
@@ -222,16 +235,12 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
Array<PrimExpr> indices;
indices.reserve(compute_op->axis.size());
for (const IterVar& iter_var : compute_op->axis) {
auto it = var_map.find(iter_var->var.get());
ICHECK(it != var_map.end());
auto it = block_var_map.find(iter_var->var.get());
ICHECK(it != block_var_map.end());
indices.push_back(it->second);
}

// Step 4. Create block body.
// helper to transform the expr and remap iters to the block domain
auto f_transform_and_remap = [&](const PrimExpr& e) {
return Substitute(info->transformer(e), var_map);
};
String block_name{nullptr};
Optional<Stmt> init = NullOpt;
Stmt body;
@@ -250,7 +259,7 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
// - A RHS operand is the value to be reduced.
for (int i = 0; i < n_buffers; ++i) {
const PrimExpr& left = BufferLoad(buffers[i], indices);
const PrimExpr& right = analyzer->Simplify(f_transform_and_remap(reduce->source[i]));
const PrimExpr& right = analyzer->Simplify(f_transform_block_domain(reduce->source[i]));
lhs.push_back(left);
rhs.push_back(right);
ICHECK_EQ(left->dtype, right->dtype);
@@ -270,15 +279,15 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
// then store the value of the variables into the target buffer positions.
for (int i = 0; i < n_buffers; ++i) {
const Buffer& buffer = buffers[i];
PrimExpr identity = f_transform_and_remap(reduce->combiner->identity_element[i]);
PrimExpr identity = f_transform_block_domain(reduce->combiner->identity_element[i]);
init_stmts.push_back(BufferStore(buffer, identity, indices));
PrimExpr value{nullptr};
if (n_buffers > 1) {
temp_vars.push_back(Var("v_" + buffer->name, PrimType(lhs[i].dtype())));
value = temp_vars.back();
} else {
PrimExpr combined = reduce->combiner.get()->operator()(lhs, rhs)[i];
value = f_transform_and_remap(combined);
value = f_transform_block_domain(combined);
}
body_stmts.push_back(BufferStore(buffer, value, indices));
}
@@ -288,15 +297,15 @@ BlockRealize GenerateBlockFromTensors(const te::ComputeOp& compute_op,
if (n_buffers > 1) {
// When there are multiple buffers, we wrap the body with LetStmts.
for (int i = n_buffers - 1; i >= 0; --i) {
PrimExpr value = f_transform_and_remap(reduce->combiner.get()->operator()(lhs, rhs)[i]);
PrimExpr value = f_transform_block_domain(reduce->combiner.get()->operator()(lhs, rhs)[i]);
body = LetStmt(temp_vars[i], std::move(value), std::move(body));
}
}
} else {
// Case 2. Data parallel compute
ICHECK_EQ(tensors.size(), 1);
block_name = info->FreshName(tensors[0]->GetNameHint());
const PrimExpr& compute_body = f_transform_and_remap(expr_body);
const PrimExpr& compute_body = f_transform_block_domain(expr_body);
body = BufferStore(info->tensor2buffers[tensors[0]], analyzer->Simplify(compute_body), indices);
}

@@ -350,12 +359,19 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
// Step 1. Creating loop vars for block bindings.
Array<IterVar> axes = compute_op->axis;
axes.insert(axes.end(), compute_op->reduce_axis.begin(), compute_op->reduce_axis.end());

std::unordered_map<const VarNode*, Var> loop_var_map;
Array<PrimExpr> bindings = axes.Map([&](IterVar iter_var) -> PrimExpr {
int bits = std::max(iter_var->dom->min.dtype().bits(), iter_var->dom->extent.dtype().bits());
return Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
auto new_var = Var(iter_var->var->name_hint, runtime::DataType::Int(bits));
loop_var_map[iter_var->var.get()] = new_var;
return new_var;
});

// helper to transform the expr and remap iters to the loop domain
auto f_transform_loop_domain = [&](const PrimExpr& e) {
return Substitute(info->transformer(e), loop_var_map);
};

// Step 2. Generate block bodies.
Array<Stmt> seq_stmt;
if (compute_op->body[0]->IsInstance<ReduceNode>()) {
@@ -383,13 +399,13 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
}

seq_stmt.push_back(GenerateBlockFromTensors(compute_op, tensors, bindings, std::move(expr_body),
info, analyzer));
info, loop_var_map, analyzer));
} else {
for (int i = 0; i < compute_op->num_outputs(); ++i) {
const te::Tensor& tensor = compute_op.output(i);
PrimExpr expr_body = compute_op->body[i];
seq_stmt.push_back(GenerateBlockFromTensors(compute_op, {tensor}, bindings,
std::move(expr_body), info, analyzer));
seq_stmt.push_back(GenerateBlockFromTensors(
compute_op, {tensor}, bindings, std::move(expr_body), info, loop_var_map, analyzer));
}
}

@@ -398,8 +414,8 @@ Stmt GenerateStmtFromCompute(const te::ComputeOp& compute_op, CreateFuncInfo* in
// Step 3. Generate loop nesting.
for (size_t i = axes.size(); i > 0; --i) {
const IterVar& axis = axes[i - 1];
PrimExpr dom_min = analyzer->Simplify(axis->dom->min);
PrimExpr dom_extent = analyzer->Simplify(axis->dom->extent);
PrimExpr dom_min = analyzer->Simplify(f_transform_loop_domain(axis->dom->min));
PrimExpr dom_extent = analyzer->Simplify(f_transform_loop_domain(axis->dom->extent));
const Var& loop_var = Downcast<Var>(bindings[i - 1]);
body = For(loop_var, dom_min, dom_extent, ForKind::kSerial, body);
}
34 changes: 34 additions & 0 deletions tests/python/te/test_te_create_primfunc.py
@@ -887,5 +887,39 @@ def te_workload():
_check_workload(te_workload, tir_workload)


def test_loop_aware():
@T.prim_func
def tir_workload(
x: T.Buffer((1, 1024, 16, 40), "float32"),
adaptive_pool_avg: T.Buffer((1, 1024, 12, 30), "float32"),
):
T.func_attr({"tir.noalias": T.bool(True), "global_symbol": "main"})
# fmt: off
adaptive_pool_sum = T.alloc_buffer((1, 1024, 12, 30))
for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
for rv0, rv1 in T.grid(T.Select((ax2 * 4 + 4) % 12 == 0, (ax2 * 16 + 16) // 12, (ax2 * 16 + 16) // 12 + 1) - ax2 * 16 // 12, T.Select((ax3 * 10 + 10) % 30 == 0, (ax3 * 40 + 40) // 30, (ax3 * 40 + 40) // 30 + 1) - ax3 * 40 // 30):
with T.block("adaptive_pool_sum"):
v_ax0, v_ax1, v_ax2, v_ax3, v_rv0, v_rv1 = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, rv0, rv1])
with T.init():
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0.0)
adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] + x[v_ax0, v_ax1, v_ax2 * 16 // 12 + v_rv0, v_ax3 * 40 // 30 + v_rv1]
for ax0, ax1, ax2, ax3 in T.grid(1, 1024, 12, 30):
with T.block("adaptive_pool_avg"):
v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
T.reads(adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3])
T.writes(adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3])
T.block_attr({"schedule_rule": "meta_schedule.adaptive_pool_avg"})
adaptive_pool_avg[v_ax0, v_ax1, v_ax2, v_ax3] = adaptive_pool_sum[v_ax0, v_ax1, v_ax2, v_ax3] / (T.Cast("float32", T.Select((v_ax2 * 4 + 4) % 12 == 0, (v_ax2 * 16 + 16) // 12, (v_ax2 * 16 + 16) // 12 + 1) - v_ax2 * 16 // 12) * T.Cast("float32", T.Select((v_ax3 * 10 + 10) % 30 == 0, (v_ax3 * 40 + 40) // 30, (v_ax3 * 40 + 40) // 30 + 1) - v_ax3 * 40 // 30))
# fmt: on

def te_workload():
x = te.placeholder([1, 1024, 16, 40], "float32", "x")
y = topi.nn.adaptive_pool(x, [12, 30], pool_type="avg")
f = te.create_prim_func([x, y])
return [x, y]

_check_workload(te_workload, tir_workload)


if __name__ == "__main__":
tvm.testing.main()
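For reference, the loop-carried pattern can also be written directly in TE without topi: a reduce_axis whose extent is an expression of the spatial iterator. This is a hypothetical 1-D sketch; the names body/pool_sum and the window arithmetic are illustrative and not part of the commit:

import tvm
from tvm import te

n_in, n_out = 16, 12
x = te.placeholder((n_in,), "float32", name="x")

def body(i):
    # Input window [start, end) covered by output element i; its length varies
    # with i, so the reduction extent depends on the spatial iterator.
    start = i * n_in // n_out
    end = (i * n_in + n_in + n_out - 1) // n_out
    rv = te.reduce_axis((0, end - start), name="rv")
    return te.sum(x[start + rv], axis=rv)

pool_sum = te.compute((n_out,), body, name="pool_sum")
f = te.create_prim_func([x, pool_sum])  # reduction loop extent references the outer loop var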
