Skip to content

Commit 7040048

Browse files
authored
polish lowering (#362)
change the return of `lowering.run` from `Statement` to `Region` so we can keep the original result as much as possible. We return `ir.Statement` only for `python_function` because a function is mapped to a `Statement`.
1 parent 147f163 commit 7040048

File tree

6 files changed

+28
-17
lines changed

6 files changed

+28
-17
lines changed

src/kirin/ir/dialect.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def wrapper(node: type[T]) -> type[T]:
110110
elif issubclass(node, MethodTable):
111111
if key in self.interps:
112112
raise ValueError(
113-
f"Cannot register {node} to Dialect, key {key} exists"
113+
f"Cannot register {node} to Dialect, key {key} exists in {self}"
114114
)
115115
self.interps[key] = node()
116116
elif issubclass(node, FromPythonAST):

src/kirin/lowering/abc.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import TYPE_CHECKING, Any, Generic, TypeVar, TypeAlias
55
from dataclasses import dataclass
66

7-
from kirin.ir import SSAValue, Statement, DialectGroup
7+
from kirin.ir import Region, SSAValue, Statement, DialectGroup
88

99
from .exception import BuildError
1010

@@ -49,7 +49,7 @@ def run(
4949
lineno_offset: int = 0,
5050
col_offset: int = 0,
5151
compactify: bool = True,
52-
) -> Statement: ...
52+
) -> Region: ...
5353

5454
@abstractmethod
5555
def visit(self, state: State[ASTNodeType], node: ASTNodeType) -> Result:

src/kirin/lowering/frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def _push_stmt(self, stmt: StmtType) -> StmtType:
6565
raise BuildError(f"unexpected builtin statement {stmt.name}")
6666
elif stmt.dialect not in self.state.parent.dialects:
6767
raise BuildError(
68-
f"Unsupported dialect `{stmt.dialect.name}` in statement {stmt.name}"
68+
f"Unsupported dialect `{stmt.dialect.name}` from statement {stmt.name}"
6969
)
7070
self.curr_block.stmts.append(stmt)
7171
if stmt.source is None:

src/kirin/lowering/python/glob.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,14 @@ class GlobalExprEval(ast.NodeVisitor):
2323
frame: Frame
2424

2525
def generic_visit(self, node: ast.AST) -> Any:
26+
if isinstance(node, ast.AST):
27+
raise GlobalEvalError(
28+
node,
29+
f"Cannot lower global {node.__class__.__name__} node: {ast.dump(node)}",
30+
)
2631
raise GlobalEvalError(
2732
node,
28-
f"Cannot lower global {node.__class__.__name__} node: {ast.dump(node)}",
33+
f"Unexpected global `{node.__class__.__name__}` node: {repr(node)} is not an AST node",
2934
)
3035

3136
def visit_Name(self, node: ast.Name) -> Any:

src/kirin/lowering/python/lowering.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def python_function(
8181
except Exception:
8282
nonlocals = {}
8383
globals.update(nonlocals)
84-
return self.run(
84+
region = self.run(
8585
ast.parse(source).body[0],
8686
source=source,
8787
globals=globals,
@@ -90,6 +90,13 @@ def python_function(
9090
col_offset=col_offset,
9191
compactify=compactify,
9292
)
93+
if not region.blocks:
94+
raise ValueError("No block generated")
95+
96+
code = region.blocks[0].first_stmt
97+
if code is None:
98+
raise ValueError("No code generated")
99+
return code
93100

94101
def run(
95102
self,
@@ -101,7 +108,7 @@ def run(
101108
lineno_offset: int = 0,
102109
col_offset: int = 0,
103110
compactify: bool = True,
104-
) -> ir.Statement:
111+
) -> ir.Region:
105112
source = source or ast.unparse(stmt)
106113
state = State(
107114
self,
@@ -132,18 +139,12 @@ def run(
132139
raise e
133140

134141
region = frame.curr_region
135-
if not region.blocks:
136-
raise ValueError("No block generated")
137-
138-
code = region.blocks[0].first_stmt
139-
if code is None:
140-
raise ValueError("No code generated")
141142

142143
if compactify:
143144
from kirin.rewrite import Walk, CFGCompactify
144145

145-
Walk(CFGCompactify()).rewrite(code)
146-
return code
146+
Walk(CFGCompactify()).rewrite(region)
147+
return region
147148

148149
def lower_literal(self, state: State[ast.AST], value) -> ir.SSAValue:
149150
return state.lower(ast.Constant(value=value)).expect_one()

src/kirin/lowering/state.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class State(Generic[ASTNodeType]):
5050
"current frame being lowered"
5151

5252
def __repr__(self) -> str:
53-
return f"lowering.State({self.current_frame})"
53+
return f"lowering.State(current_frame={self._current_frame})"
5454

5555
@property
5656
def code(self):
@@ -178,14 +178,19 @@ def frame(
178178
entr_block = entr_block or Block()
179179
region.blocks.append(entr_block)
180180

181+
if self._current_frame is not None:
182+
globals = globals or self.current_frame.globals
183+
else:
184+
globals = globals or {}
185+
181186
frame = Frame(
182187
state=self,
183188
stream=stmts,
184189
curr_region=region or Region(entr_block),
185190
entr_block=entr_block,
186191
curr_block=entr_block,
187192
next_block=next_block or Block(),
188-
globals=globals or self.current_frame.globals,
193+
globals=globals,
189194
capture_callback=capture_callback,
190195
)
191196
self.push_frame(frame)

0 commit comments

Comments
 (0)