|
| 1 | +import ast |
| 2 | + |
| 3 | +import rich |
| 4 | + |
| 5 | +from kirin import ir, decl, types, interp, lowering, exceptions |
| 6 | + |
| 7 | +dialect = ir.Dialect("debug") |
| 8 | + |
| 9 | + |
| 10 | +class InfoLowering(ir.FromPythonCall): |
| 11 | + |
| 12 | + def lower( |
| 13 | + self, stmt: type, state: lowering.LoweringState, node: ast.Call |
| 14 | + ) -> lowering.Result: |
| 15 | + if len(node.args) == 0: |
| 16 | + raise exceptions.DialectLoweringError( |
| 17 | + "info() requires at least one argument" |
| 18 | + ) |
| 19 | + |
| 20 | + msg = state.visit(node.args[0]).expect_one() |
| 21 | + if len(node.args) > 1: |
| 22 | + inputs = tuple(state.visit(arg).expect_one() for arg in node.args[1:]) |
| 23 | + else: |
| 24 | + inputs = () |
| 25 | + return lowering.Result(state.append_stmt(Info(msg=msg, inputs=inputs))) |
| 26 | + |
| 27 | + |
| 28 | +@decl.statement(dialect=dialect) |
| 29 | +class Info(ir.Statement): |
| 30 | + """print debug information. |
| 31 | +
|
| 32 | + This statement is used to print debug information during |
| 33 | + execution. The compiler has freedom to choose how to print |
| 34 | + the information and send it back to the caller. Note that |
| 35 | + in the case of heterogeneous hardware, this may not be printed |
| 36 | + on the same device as the caller but instead being a log. |
| 37 | + """ |
| 38 | + |
| 39 | + traits = frozenset({InfoLowering()}) |
| 40 | + msg: ir.SSAValue = decl.info.argument(types.String) |
| 41 | + inputs: tuple[ir.SSAValue, ...] = decl.info.argument() |
| 42 | + |
| 43 | + |
| 44 | +@lowering.wraps(Info) |
| 45 | +def info(msg: str, *inputs) -> None: ... |
| 46 | + |
| 47 | + |
| 48 | +@dialect.register(key="main") |
| 49 | +class ConcreteMethods(interp.MethodTable): |
| 50 | + |
| 51 | + @interp.impl(Info) |
| 52 | + def info(self, interp: interp.Interpreter, frame: interp.Frame, stmt: Info): |
| 53 | + # print("INFO:", frame.get(stmt.msg)) |
| 54 | + rich.print( |
| 55 | + "[dim]┌───────────────────────────────────────────────────────────────[/dim]" |
| 56 | + ) |
| 57 | + rich.print("[dim]│[/dim] [bold cyan]INFO:[/bold cyan] ", end="", sep="") |
| 58 | + print(frame.get(stmt.msg)) |
| 59 | + for input in stmt.inputs: |
| 60 | + rich.print( |
| 61 | + "[dim]│[/dim] ", |
| 62 | + input.name or "unknown", |
| 63 | + "[dim] = [/dim]", |
| 64 | + end="", |
| 65 | + sep="", |
| 66 | + ) |
| 67 | + print(frame.get(input)) |
| 68 | + rich.print( |
| 69 | + "[dim]└───────────────────────────────────────────────────────────────[/dim]" |
| 70 | + ) |
0 commit comments