Skip to content

Commit e4096e9

Browse files
authored
Merge branch 'main' into david/fidelity-scf
2 parents 0b2620b + 7523980 commit e4096e9

31 files changed

+709
-559
lines changed

src/bloqade/squin/analysis/nsites/impls.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
from kirin import interp
2+
from kirin.dialects import scf
3+
from kirin.dialects.scf.typeinfer import TypeInfer as ScfTypeInfer
24

35
from bloqade.squin import op, wire
46

@@ -78,3 +80,8 @@ def rot(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Rot):
7880
def scale(self, interp: NSitesAnalysis, frame: interp.Frame, stmt: op.stmts.Scale):
7981
op_sites = frame.get(stmt.op)
8082
return (op_sites,)
83+
84+
85+
@scf.dialect.register(key="op.nsites")
86+
class ScfSquinOp(ScfTypeInfer):
87+
pass

src/bloqade/squin/passes/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

src/bloqade/squin/passes/stim.py

Lines changed: 0 additions & 68 deletions
This file was deleted.

src/bloqade/squin/rewrite/__init__.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
1-
from .wire_to_stim import SquinWireToStim as SquinWireToStim
2-
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
3-
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
41
from .wrap_analysis import (
52
SitesAttribute as SitesAttribute,
63
AddressAttribute as AddressAttribute,
7-
WrapSquinAnalysis as WrapSquinAnalysis,
8-
)
9-
from .wire_identity_elimination import (
10-
SquinWireIdentityElimination as SquinWireIdentityElimination,
4+
WrapOpSiteAnalysis as WrapOpSiteAnalysis,
5+
WrapAddressAnalysis as WrapAddressAnalysis,
116
)
7+
from .remove_dangling_qubits import RemoveDeadRegister as RemoveDeadRegister
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from kirin import ir
2+
from kirin.rewrite.abc import RewriteRule, RewriteResult
3+
4+
from bloqade.squin import qubit
5+
6+
7+
class RemoveDeadRegister(RewriteRule):
8+
9+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
10+
11+
if not isinstance(node, qubit.New):
12+
return RewriteResult()
13+
14+
if bool(node.result.uses):
15+
return RewriteResult()
16+
else:
17+
node.delete()
18+
19+
return RewriteResult(has_done_something=True)

src/bloqade/squin/rewrite/wrap_analysis.py

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from abc import abstractmethod
12
from dataclasses import dataclass
23

34
from kirin import ir
@@ -40,33 +41,47 @@ def print_impl(self, printer: Printer) -> None:
4041

4142

4243
@dataclass
43-
class WrapSquinAnalysis(RewriteRule):
44+
class WrapAnalysis(RewriteRule):
4445

46+
@abstractmethod
47+
def wrap(self, value: ir.SSAValue) -> bool:
48+
pass
49+
50+
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
51+
has_done_something = any(self.wrap(arg) for arg in node.args)
52+
return RewriteResult(has_done_something=has_done_something)
53+
54+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
55+
has_done_something = any(self.wrap(result) for result in node.results)
56+
return RewriteResult(has_done_something=has_done_something)
57+
58+
59+
@dataclass
60+
class WrapAddressAnalysis(WrapAnalysis):
4561
address_analysis: dict[ir.SSAValue, Address]
46-
op_site_analysis: dict[ir.SSAValue, Sites]
4762

4863
def wrap(self, value: ir.SSAValue) -> bool:
4964
address_analysis_result = self.address_analysis[value]
50-
op_site_analysis_result = self.op_site_analysis[value]
5165

52-
if value.hints.get("address") and value.hints.get("sites"):
66+
if value.hints.get("address") is not None:
5367
return False
54-
else:
55-
value.hints["address"] = AddressAttribute(address_analysis_result)
56-
value.hints["sites"] = SitesAttribute(op_site_analysis_result)
68+
69+
value.hints["address"] = AddressAttribute(address_analysis_result)
5770

5871
return True
5972

60-
def rewrite_Block(self, node: ir.Block) -> RewriteResult:
61-
has_done_something = False
62-
for arg in node.args:
63-
if self.wrap(arg):
64-
has_done_something = True
65-
return RewriteResult(has_done_something=has_done_something)
6673

67-
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
68-
has_done_something = False
69-
for result in node.results:
70-
if self.wrap(result):
71-
has_done_something = True
72-
return RewriteResult(has_done_something=has_done_something)
74+
@dataclass
75+
class WrapOpSiteAnalysis(WrapAnalysis):
76+
77+
op_site_analysis: dict[ir.SSAValue, Sites]
78+
79+
def wrap(self, value: ir.SSAValue) -> bool:
80+
op_site_analysis_result = self.op_site_analysis[value]
81+
82+
if value.hints.get("sites") is not None:
83+
return False
84+
85+
value.hints["sites"] = SitesAttribute(op_site_analysis_result)
86+
87+
return True

src/bloqade/stim/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from . import emit as emit, parse as parse, dialects as dialects
1+
from . import emit as emit, parse as parse, passes as passes, dialects as dialects
22
from .groups import main as main
33
from ._wrappers import (
44
h as h,

src/bloqade/stim/dialects/auxiliary/stmts/const.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def print_impl(self, printer: Printer) -> None:
4747

4848
@statement(dialect=dialect)
4949
class ConstBool(ir.Statement):
50-
"""IR Statement representing a constant float value."""
50+
"""IR Statement representing a constant boolean value."""
5151

5252
name = "constant.bool"
5353
traits = frozenset({ir.Pure(), ir.ConstantLike(), lowering.FromPythonCall()})

src/bloqade/stim/passes/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .squin_to_stim import SquinToStim as SquinToStim
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from dataclasses import dataclass
2+
3+
from kirin.passes import Fold
4+
from kirin.rewrite import (
5+
Walk,
6+
Chain,
7+
Fixpoint,
8+
DeadCodeElimination,
9+
CommonSubexpressionElimination,
10+
)
11+
from kirin.ir.method import Method
12+
from kirin.passes.abc import Pass
13+
from kirin.rewrite.abc import RewriteResult
14+
15+
from bloqade.stim.groups import main as stim_main_group
16+
from bloqade.stim.rewrite import (
17+
SquinWireToStim,
18+
PyConstantToStim,
19+
SquinQubitToStim,
20+
SquinMeasureToStim,
21+
SquinWireIdentityElimination,
22+
)
23+
from bloqade.squin.rewrite import RemoveDeadRegister
24+
25+
26+
@dataclass
27+
class SquinToStim(Pass):
28+
29+
def unsafe_run(self, mt: Method) -> RewriteResult:
30+
fold_pass = Fold(mt.dialects)
31+
# propagate constants
32+
rewrite_result = fold_pass(mt)
33+
34+
# Assume that address analysis and
35+
# wrapping has been done before this pass!
36+
37+
# Wrap Rewrite + SquinToStim can happen w/ standard walk
38+
rewrite_result = (
39+
Walk(
40+
Chain(
41+
SquinQubitToStim(),
42+
SquinWireToStim(),
43+
SquinMeasureToStim(), # reduce duplicated logic, can split out even more rules later
44+
SquinWireIdentityElimination(),
45+
)
46+
)
47+
.rewrite(mt.code)
48+
.join(rewrite_result)
49+
)
50+
51+
# Convert all PyConsts to Stim Constants
52+
rewrite_result = (
53+
Walk(Chain(PyConstantToStim())).rewrite(mt.code).join(rewrite_result)
54+
)
55+
56+
# remove any squin.qubit.new that's left around
57+
## Not considered pure so DCE won't touch it but
58+
## it isn't being used anymore considering everything is a
59+
## stim dialect statement
60+
rewrite_result = (
61+
Fixpoint(
62+
Walk(
63+
Chain(
64+
DeadCodeElimination(),
65+
CommonSubexpressionElimination(),
66+
RemoveDeadRegister(),
67+
)
68+
)
69+
)
70+
.rewrite(mt.code)
71+
.join(rewrite_result)
72+
)
73+
74+
# do program verification here,
75+
# ideally use built-in .verify() to catch any
76+
# incompatible statements as the full rewrite process should not
77+
# leave statements from any other dialects (other than the stim main group)
78+
mt_verification_clone = mt.similar(stim_main_group)
79+
80+
# suggested by Kai, will work for now
81+
for stmt in mt_verification_clone.code.walk():
82+
assert (
83+
stmt.dialect in stim_main_group
84+
), "Statements detected that are not part of the stim dialect, please verify the original code is valid for rewrite!"
85+
86+
return rewrite_result

src/bloqade/stim/rewrite/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .wire_to_stim import SquinWireToStim as SquinWireToStim
2+
from .qubit_to_stim import SquinQubitToStim as SquinQubitToStim
3+
from .squin_measure import SquinMeasureToStim as SquinMeasureToStim
4+
from .py_constant_to_stim import PyConstantToStim as PyConstantToStim
5+
from .wire_identity_elimination import (
6+
SquinWireIdentityElimination as SquinWireIdentityElimination,
7+
)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from kirin import ir
2+
from kirin.dialects import py
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from bloqade.stim.dialects import auxiliary
6+
7+
# py.constant.int -> stim.const.ConstInt
8+
# py.constant.float -> stimt.const.ConstFloat
9+
# py.constant -> stim.const.ConstBool
10+
#
11+
12+
13+
class PyConstantToStim(RewriteRule):
14+
15+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
16+
17+
match node:
18+
case py.constant.Constant():
19+
return self.rewrite_PyConstant(node)
20+
case _:
21+
return RewriteResult()
22+
23+
def rewrite_PyConstant(self, node: py.constant.Constant) -> RewriteResult:
24+
25+
# node.value is a PyAttr, need to get the
26+
# wrapped value out
27+
wrapped_value = node.value.unwrap()
28+
29+
if isinstance(wrapped_value, int):
30+
stim_const = auxiliary.ConstInt(value=wrapped_value)
31+
elif isinstance(wrapped_value, float):
32+
stim_const = auxiliary.ConstFloat(value=wrapped_value)
33+
elif isinstance(wrapped_value, bool):
34+
stim_const = auxiliary.ConstBool(value=wrapped_value)
35+
elif isinstance(wrapped_value, str):
36+
stim_const = auxiliary.ConstStr(value=wrapped_value)
37+
else:
38+
return RewriteResult()
39+
40+
node.replace_by(stim_const)
41+
42+
return RewriteResult(has_done_something=True)

src/bloqade/squin/rewrite/qubit_to_stim.py renamed to src/bloqade/stim/rewrite/qubit_to_stim.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
from kirin.rewrite.abc import RewriteRule, RewriteResult
33

44
from bloqade.squin import op, qubit
5-
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
6-
from bloqade.squin.rewrite.stim_rewrite_util import (
5+
from bloqade.squin.rewrite import AddressAttribute
6+
from bloqade.stim.rewrite.util import (
77
SQUIN_STIM_GATE_MAPPING,
88
rewrite_Control,
99
insert_qubit_idx_from_address,
@@ -41,9 +41,11 @@ def rewrite_Apply_and_Broadcast(
4141
return RewriteResult()
4242

4343
address_attr = stmt.qubits.hints.get("address")
44+
4445
if address_attr is None:
4546
return RewriteResult()
4647

48+
# sometimes you can get a whole AddressReg...
4749
assert isinstance(address_attr, AddressAttribute)
4850
qubit_idx_ssas = insert_qubit_idx_from_address(
4951
address=address_attr, stmt_to_insert_before=stmt

src/bloqade/squin/rewrite/squin_measure.py renamed to src/bloqade/stim/rewrite/squin_measure.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from kirin.rewrite.abc import RewriteRule, RewriteResult
55

66
from bloqade.squin import wire, qubit
7+
from bloqade.squin.rewrite import AddressAttribute
78
from bloqade.stim.dialects import collapse
8-
from bloqade.squin.rewrite.wrap_analysis import AddressAttribute
9-
from bloqade.squin.rewrite.stim_rewrite_util import (
9+
from bloqade.stim.rewrite.util import (
1010
is_measure_result_used,
1111
insert_qubit_idx_from_address,
1212
)

0 commit comments

Comments
 (0)