diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc index 74200526b699..44e338cbe8ca 100644 --- a/src/relax/transform/static_plan_block_memory.cc +++ b/src/relax/transform/static_plan_block_memory.cc @@ -314,6 +314,15 @@ class StorageAllocatorBaseVisitor : public ExprVisitor { SetTokens(binding->var.get(), token_map_[binding->value.get()]); } + void VisitBindingBlock_(const DataflowBlockNode* block) override { + // We maintain a block stack for token allocation-site and use-site check. + block_stack_.push_back(block); + ExprVisitor::VisitBindingBlock_(block); + ICHECK(!block_stack_.empty()); + ICHECK(block_stack_.back() == block); + block_stack_.pop_back(); + } + void VisitExpr_(const TupleNode* tuple) final { Array tokens; tokens.reserve(tuple->fields.size()); diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py index 1150827b19f9..28015f0eecff 100644 --- a/tests/python/relax/test_transform_static_plan_block_memory.py +++ b/tests/python/relax/test_transform_static_plan_block_memory.py @@ -1504,5 +1504,46 @@ def main() -> R.Tensor((128,), dtype="float32"): tvm.ir.assert_structural_equal(after, Expected) +def test_with_dataflow(): + @I.ir_module + class Before: + @T.prim_func + def exp(A: T.handle, B: T.handle): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Before + with R.dataflow(): + alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([10]), R.dtype("float32"), runtime_device_index=0 + ) + _: R.Tuple() = cls.exp(x, alloc) + gv: R.Tensor((10,), dtype="float32") = alloc + R.output(gv) + return gv + + @I.ir_module + class Expected: + @T.prim_func + def exp(A: T.handle, B: T.handle): + T.evaluate(0) + + @R.function + def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"): + cls = Expected + with R.dataflow(): + alloc: R.Tensor((10,), dtype="float32") = R.builtin.alloc_tensor( + R.shape([10]), R.dtype("float32"), R.prim_value(0), R.str("global") + ) + cls.exp(x, alloc) + gv: R.Tensor((10,), dtype="float32") = alloc + R.output(gv) + return gv + + after = relax.transform.StaticPlanBlockMemory()(Before) + tvm.ir.assert_structural_equal(after, Expected) + + if __name__ == "__main__": tvm.testing.main()