Skip to content

Issue with combining `scf` unrolling and inlining of recursive functions #405

Open
@weinbe58

Description

@weinbe58

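I'm trying to combine `scf.unroll` rewrites with inlining, but I'm seeing odd behavior with recursive functions. When the pass below is run several times on the same method, it eventually folds the function to an incorrect result (`entry()` should evaluate to `4`, since `test_recursion(3)` adds `1` three times to the base case `1`).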

from dataclasses import dataclass, field
from typing import Callable
from kirin.ir.method import Method
from kirin.prelude import structural
from kirin.dialects import ilist, scf
from kirin import rewrite, ir

from kirin.passes import Pass, HintConst, TypeInfer
from kirin.rewrite.abc import RewriteResult



@dataclass
class UnrollPass(Pass):
    """Attempt to unroll ``scf`` control flow and inline function calls.

    A single run performs, in order:

    1. ``HintConst`` on the method;
    2. a fixpoint of constant folding, call-to-invoke conversion,
       getfield/getitem inlining, dead-code elimination and CSE;
    3. one walk of ``scf`` if/else picking, for-loop unrolling and
       unused-yield trimming;
    4. type inference, then one walk of ``rewrite.Inline`` driven by
       ``inline_heuristic``, then CFG compaction.

    The default ``inline_heuristic`` accepts every statement, including calls
    to recursive functions, so running this pass repeatedly can keep
    expanding such calls.
    """

    # Predicate handed to ``rewrite.Inline``; the default accepts every statement.
    inline_heuristic: Callable[[ir.Statement], bool] = field(
        default=lambda _: True, kw_only=True
    )

    def __post_init__(self):
        # Build the type-inference sub-pass once. Forward ``no_raise`` so it
        # follows the same error policy as the HintConst sub-pass in unsafe_run.
        self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)

    def unsafe_run(self, mt: Method) -> RewriteResult:
        """Run one round of hinting, folding, unrolling and inlining on ``mt``.

        Args:
            mt: The method to rewrite in place.

        Returns:
            The joined ``RewriteResult`` of every sub-pass and rewrite applied.
        """
        result = RewriteResult()

        result = HintConst(mt.dialects, no_raise=self.no_raise).unsafe_run(mt).join(result)
        mt.print()  # debug output: IR after constant hinting

        simplify = rewrite.Chain(
            rewrite.ConstantFold(),
            rewrite.Call2Invoke(),
            rewrite.InlineGetField(),
            rewrite.InlineGetItem(),
            rewrite.DeadCodeElimination(),
            rewrite.CommonSubexpressionElimination(),
        )
        result = rewrite.Fixpoint(rewrite.Walk(simplify)).rewrite(mt.code).join(result)
        mt.print()  # debug output: IR after simplification

        unroll = rewrite.Walk(
            rewrite.Chain(
                scf.unroll.PickIfElse(),
                scf.unroll.ForLoop(),
                scf.trim.UnusedYield(),
            )
        )
        result = unroll.rewrite(mt.code).join(result)

        # Run type inference again after unrolling: it inserts many new nodes,
        # which may now have more precise types. Join its result (previously
        # discarded) so callers see that the method changed.
        result = self.typeinfer.unsafe_run(mt).join(result)

        result = (
            rewrite.Walk(rewrite.Inline(self.inline_heuristic))
            .rewrite(mt.code)
            .join(result)
        )
        result = (
            rewrite.Walk(rewrite.Fixpoint(rewrite.CFGCompactify()))
            .rewrite(mt.code)
            .join(result)
        )
        mt.print()  # debug output: final IR
        return result
        
       
def test_fold_scf_if():
    """Regression repro for #405: repeated unroll+inline must preserve semantics.

    ``test_recursion(3)`` evaluates to ``4`` (``1`` from the base case plus
    three ``+ 1`` steps). ``UnrollPass`` is run five times on ``entry``. After
    every run, the rewritten ``entry`` must still evaluate to the value it had
    before any rewriting.
    """

    @structural
    def test_recursion(depth: int) -> int:
        if depth > 0:
            return test_recursion(depth - 1) + 1
        else:
            return 1

    @structural
    def entry():
        return test_recursion(3)

    # Interpret the untouched method first to get the reference result.
    expected = entry()
    assert expected == 4

    for _ in range(5):
        UnrollPass(entry.dialects)(entry)
        # Checking after every run pinpoints which iteration miscompiles.
        assert entry() == expected



if __name__ == "__main__":
    # Run the reproduction directly when executed as a script (not under pytest).
    test_fold_scf_if()
        

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions