Open
Description
I'm trying to combine `scf.unroll` rewrites with inlining, but I'm seeing odd behavior with recursive functions. When the pass below is run several times in a row, it eventually folds the function to an incorrect result (`entry()` should evaluate to `test_recursion(3) == 4`).
from dataclasses import dataclass, field
from typing import Callable
from kirin.ir.method import Method
from kirin.prelude import structural
from kirin.dialects import ilist, scf
from kirin import rewrite, ir
from kirin.passes import Pass, HintConst, TypeInfer
from kirin.rewrite.abc import RewriteResult
@dataclass
class UnrollPass(Pass):
    """Attempt to unroll loops and inline functions.

    A single invocation runs, in order:

    1. constant hinting (``HintConst``);
    2. a fixpoint of local cleanups (constant folding, ``Call2Invoke``,
       field/item inlining, dead-code and common-subexpression elimination);
    3. one walk of the ``scf`` unrolling rewrites (``PickIfElse``,
       ``ForLoop``, ``UnusedYield``);
    4. type inference on the rewritten body;
    5. one walk of call inlining gated by ``inline_heuristic``;
    6. CFG compaction.

    Stages 3 and 5 each run as a single walk. Re-running the pass may
    therefore expose further opportunities, for example call sites that
    inlining introduced.
    """

    # Predicate deciding whether a given call statement is inlined. The
    # default returns True for every statement, including recursive calls.
    inline_heuristic: Callable[[ir.Statement], bool] = field(
        default=lambda _: True, kw_only=True
    )
    # When True, print the IR after the cleanup stage and at the end of
    # each stage group (useful when debugging the rewrites).
    verbose: bool = field(default=True, kw_only=True)

    def __post_init__(self):
        # Build the type-inference pass once so repeated runs reuse it, and
        # honour the same ``no_raise`` setting used for HintConst below.
        self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)

    def _dump(self, mt: Method) -> None:
        """Print ``mt`` when ``verbose`` is enabled."""
        if self.verbose:
            mt.print()

    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Run one round of hinting, cleanup, unrolling, inlining and compaction.

        Args:
            mt: The method whose body is rewritten in place.

        Returns:
            The join of every stage's ``RewriteResult``.
        """
        result = RewriteResult()
        result = (
            HintConst(mt.dialects, no_raise=self.no_raise)
            .unsafe_run(mt)
            .join(result)
        )
        self._dump(mt)

        cleanup = rewrite.Chain(
            rewrite.ConstantFold(),
            rewrite.Call2Invoke(),
            rewrite.InlineGetField(),
            rewrite.InlineGetItem(),
            rewrite.DeadCodeElimination(),
            rewrite.CommonSubexpressionElimination(),
        )
        result = rewrite.Fixpoint(rewrite.Walk(cleanup)).rewrite(mt.code).join(result)
        self._dump(mt)

        unroll = rewrite.Walk(
            rewrite.Chain(
                scf.unroll.PickIfElse(),
                scf.unroll.ForLoop(),
                scf.trim.UnusedYield(),
            )
        )
        result = unroll.rewrite(mt.code).join(result)

        # Run type inference again after unrolling: the rewrites above insert
        # many new statements that may now have more precise types. Its
        # result is joined as well, so changes made here are not dropped.
        result = self.typeinfer.unsafe_run(mt).join(result)

        inline = rewrite.Walk(rewrite.Inline(self.inline_heuristic))
        result = inline.rewrite(mt.code).join(result)

        compact = rewrite.Walk(rewrite.Fixpoint(rewrite.CFGCompactify()))
        result = compact.rewrite(mt.code).join(result)
        self._dump(mt)
        return result
def test_fold_scf_if():
    """Regression repro: repeated UnrollPass runs must preserve ``entry()``'s value.

    ``test_recursion(3)`` evaluates to ``4``: three ``+ 1`` steps on top of
    the base case ``1``. Each pass run may inline a further level of the
    recursion, so the value is checked after every run. This pinpoints the
    iteration that first folds ``entry`` to a wrong result.
    """

    @structural
    def test_recursion(depth: int) -> int:
        if depth > 0:
            return test_recursion(depth - 1) + 1
        else:
            return 1

    @structural
    def entry():
        return test_recursion(3)

    expected = 4
    assert entry() == expected

    for run in range(1, 6):
        UnrollPass(entry.dialects)(entry)
        actual = entry()
        assert actual == expected, (
            f"after UnrollPass run {run}: entry() returned {actual}, "
            f"expected {expected}"
        )
if __name__ == "__main__":
    # Running this file directly executes the reproduction without pytest.
    test_fold_scf_if()
Metadata
Metadata
Assignees
Labels
No labels