Skip to content

Commit

Permalink
Fix InternalError in StaticPlanBlockMemory when visiting DataflowBloc…
Browse files Browse the repository at this point in the history
…kNode (#17501)

This PR fixes an internal error #17488 

This error happens because the visitor class StorageAllocatorBaseVisitor
does not correctly handle DataflowBlockNode instances.
Specifically, the VisitBindingBlock_ method is not overridden
for DataflowBlockNode, leading to an empty block_stack_
when it is expected to contain the current block.

To fix this issue, we need to override the VisitBindingBlock_
method for const DataflowBlockNode* in the
StorageAllocatorBaseVisitor class. By doing so, we ensure that
the block_stack_ is correctly managed when visiting dataflow
blocks, similar to how it is managed for regular binding blocks.
  • Loading branch information
Thrsu authored Nov 13, 2024
1 parent c7e9292 commit 3d96623
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/relax/transform/static_plan_block_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,15 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
SetTokens(binding->var.get(), token_map_[binding->value.get()]);
}

void VisitBindingBlock_(const DataflowBlockNode* block) override {
// We maintain a block stack for token allocation-site and use-site check.
block_stack_.push_back(block);
ExprVisitor::VisitBindingBlock_(block);
ICHECK(!block_stack_.empty());
ICHECK(block_stack_.back() == block);
block_stack_.pop_back();
}

void VisitExpr_(const TupleNode* tuple) final {
Array<Tokens> tokens;
tokens.reserve(tuple->fields.size());
Expand Down
41 changes: 41 additions & 0 deletions tests/python/relax/test_transform_static_plan_block_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -1504,5 +1504,46 @@ def main() -> R.Tensor((128,), dtype="float32"):
tvm.ir.assert_structural_equal(after, Expected)


def test_with_dataflow():
@I.ir_module
class Before:
@T.prim_func
def exp(A: T.handle, B: T.handle):
T.evaluate(0)

@R.function
def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
cls = Before
with R.dataflow():
alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
R.shape([10]), R.dtype("float32"), runtime_device_index=0
)
_: R.Tuple() = cls.exp(x, alloc)
gv: R.Tensor((10,), dtype="float32") = alloc
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func
def exp(A: T.handle, B: T.handle):
T.evaluate(0)

@R.function
def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
cls = Expected
with R.dataflow():
alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global")
)
cls.exp(x, alloc)
gv: R.Tensor((10,), dtype="float32") = alloc
R.output(gv)
return gv

after = relax.transform.StaticPlanBlockMemory()(Before)
tvm.ir.assert_structural_equal(after, Expected)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 3d96623

Please sign in to comment.