From 3d966230caa63b4ad8c3d6c86aaad27d5a8a0918 Mon Sep 17 00:00:00 2001
From: Thrsu <89128704+Thrsu@users.noreply.github.com>
Date: Wed, 13 Nov 2024 13:35:42 +0800
Subject: [PATCH] Fix InternalError in StaticPlanBlockMemory when visiting DataflowBlockNode (#17501)

This PR fixes an internal error #17488

This error happens because the visitor class StorageAllocatorBaseVisitor does not correctly handle DataflowBlockNode instances. Specifically, the VisitBindingBlock_ method is not overridden for DataflowBlockNode, leading to an empty block_stack_ when it is expected to contain the current block.

To fix this issue, we need to override the VisitBindingBlock_ method for const DataflowBlockNode* in the StorageAllocatorBaseVisitor class. By doing so, we ensure that the block_stack_ is correctly managed when visiting dataflow blocks, similar to how it is managed for regular binding blocks.
---
Note (review): the added override pushes the DataflowBlockNode onto
block_stack_, delegates to ExprVisitor::VisitBindingBlock_, checks (ICHECK)
that the stack is non-empty and that its top is still this block, and then
pops it. The added test_with_dataflow runs relax.transform.StaticPlanBlockMemory
on a module whose main places R.builtin.alloc_tensor inside R.dataflow() and
compares the result against Expected with tvm.ir.assert_structural_equal.

 .../transform/static_plan_block_memory.cc     |  9 ++++
 ...test_transform_static_plan_block_memory.py | 41 +++++++++++++++++++
 2 files changed, 50 insertions(+)

diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index 74200526b699..44e338cbe8ca 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -314,6 +314,15 @@ class StorageAllocatorBaseVisitor : public ExprVisitor {
     SetTokens(binding->var.get(), token_map_[binding->value.get()]);
   }
 
+  void VisitBindingBlock_(const DataflowBlockNode* block) override {
+    // We maintain a block stack for token allocation-site and use-site check.
+    block_stack_.push_back(block);
+    ExprVisitor::VisitBindingBlock_(block);
+    ICHECK(!block_stack_.empty());
+    ICHECK(block_stack_.back() == block);
+    block_stack_.pop_back();
+  }
+
   void VisitExpr_(const TupleNode* tuple) final {
     Array tokens;
     tokens.reserve(tuple->fields.size());
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 1150827b19f9..28015f0eecff 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1504,5 +1504,46 @@ def main() -> R.Tensor((128,), dtype="float32"):
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_with_dataflow():
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Before
+            with R.dataflow():
+                alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([10]), R.dtype("float32"), runtime_device_index=0
+                )
+                _: R.Tuple() = cls.exp(x, alloc)
+                gv: R.Tensor((10,), dtype="float32") = alloc
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def exp(A: T.handle, B: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
+            cls = Expected
+            with R.dataflow():
+                alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor(
+                    R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global")
+                )
+                cls.exp(x, alloc)
+                gv: R.Tensor((10,), dtype="float32") = alloc
+                R.output(gv)
+            return gv
+
+    after = relax.transform.StaticPlanBlockMemory()(Before)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()