Skip to content

Emit a cirq.Circuit from a squin kernel #311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9093dd2
Emit cirq.Circuit from squin kernel
david-pl Jun 3, 2025
4d63742
Better typing in emit
david-pl Jun 3, 2025
e124955
Impl all pauli ops
david-pl Jun 3, 2025
d67c2d0
Impl for controls
david-pl Jun 3, 2025
4bf7647
Inline pass for shorthand squin wrappers
david-pl Jun 4, 2025
5d9469b
Fix typing in heuristic
david-pl Jun 4, 2025
dfc076c
Allow passing in custom list of qubits
david-pl Jun 4, 2025
3c7bb64
Emit func invoke as subcircuit
david-pl Jun 4, 2025
c16e99c
Fix setting block arguments
david-pl Jun 4, 2025
7ff58a3
Fix qubit index
david-pl Jun 4, 2025
9f04d7d
Test nested kernels
david-pl Jun 4, 2025
624ad93
Better error message
david-pl Jun 4, 2025
c309ad1
Update qubit index in frame when emitting invoke
david-pl Jun 4, 2025
179d2cf
Properly deal with return values
david-pl Jun 4, 2025
41e799d
Don't add all values to a sub_frame
david-pl Jun 4, 2025
59cf51a
Remove inline pass that is no longer needed
david-pl Jun 4, 2025
894cd53
Add a test when returning qubits from a nested kernel
david-pl Jun 4, 2025
80724b9
Implement missing qubit methods and remove return value from kernel w…
david-pl Jun 4, 2025
0c75266
Error when trying to emit a method with return value
david-pl Jun 4, 2025
ba64cde
Add operator methods with TODO
david-pl Jun 4, 2025
7373a00
Restructure files
david-pl Jun 4, 2025
f5348d5
Start implementing runtime
david-pl Jun 4, 2025
bf9d88a
Projector and Sp, Sn runtime
david-pl Jun 5, 2025
ca73dab
Fix type
david-pl Jun 5, 2025
2cf45fc
Adjoint runtime
david-pl Jun 5, 2025
d7e04cb
U3 runtime
david-pl Jun 5, 2025
e955fd0
Scale runtime
david-pl Jun 5, 2025
78f1a3d
Phase op runtime
david-pl Jun 5, 2025
1970b79
Shift op and reset runtime
david-pl Jun 5, 2025
b2f0e47
PauliString runtime
david-pl Jun 5, 2025
b0e4972
Split out runtime
david-pl Jun 5, 2025
a7a4d4c
Add docstrings and some examples
david-pl Jun 5, 2025
f0aeae6
Merge branch 'main' into david/squin-to-cirq-emit
david-pl Jun 10, 2025
7e56a29
Require single block in emit_invoke
david-pl Jun 10, 2025
f2ffed1
Fix flaky test
david-pl Jun 10, 2025
96955e7
Merge branch 'main' into david/squin-to-cirq-emit
david-pl Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion src/bloqade/squin/cirq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from typing import Any
from typing import Any, Sequence

import cirq
from kirin import ir, types
from kirin.emit import EmitError
from kirin.dialects import func

from . import lowering as lowering
from .. import kernel
from .lowering import Squin
from .emit_circuit import EmitCirq


def load_circuit(
Expand Down Expand Up @@ -87,3 +89,25 @@
dialects=dialects,
code=code,
)


def emit_circuit(
mt: ir.Method,
args=(),
qubits: Sequence[cirq.Qid] | None = None,
) -> cirq.Circuit:

if isinstance(mt.code, func.Function) and not mt.code.signature.output.is_subseteq(
types.NoneType
):
raise EmitError(
"The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
)

emitter = EmitCirq(qubits=qubits)
return emitter.run(mt, args=args)


def dump_circuit(mt: ir.Method, args=(), **kwargs):
circuit = emit_circuit(mt, args=args)
return cirq.to_json(circuit, **kwargs)

Check warning on line 113 in src/bloqade/squin/cirq/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/__init__.py#L112-L113

Added lines #L112 - L113 were not covered by tests
242 changes: 242 additions & 0 deletions src/bloqade/squin/cirq/emit_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
from typing import Sequence
from dataclasses import field, dataclass

import cirq
from kirin import ir, types
from kirin.emit import EmitABC, EmitFrame
from kirin.interp import MethodTable, impl
from kirin.dialects import func
from typing_extensions import Self

from .. import op, qubit, kernel


@dataclass
class EmitCirqFrame(EmitFrame):
qubit_index: int = 0
qubits: Sequence[cirq.Qid] | None = None
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)


def _default_kernel():
return kernel


@dataclass
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
keys = ["emit.cirq", "main"]
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
void = cirq.Circuit()
qubits: Sequence[cirq.Qid] | None = None

def initialize(self) -> Self:
return super().initialize()

def initialize_frame(
self, code: ir.Statement, *, has_parent_access: bool = False
) -> EmitCirqFrame:
return EmitCirqFrame(
code, has_parent_access=has_parent_access, qubits=self.qubits
)

def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
return self.run_callable(method.code, args)

def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
for stmt in block.stmts:
result = self.eval_stmt(frame, stmt)
if isinstance(result, tuple):
frame.set_values(stmt.results, result)

return frame.circuit


@func.dialect.register(key="emit.cirq")
class FuncEmit(MethodTable):

@impl(func.Function)
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
emit.run_ssacfg_region(frame, stmt.body, ())
return (frame.circuit,)

@impl(func.Invoke)
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
args = stmt.inputs
ret = stmt.result

with emit.new_frame(stmt.callee.code) as sub_frame:
sub_frame.qubit_index = frame.qubit_index
sub_frame.qubits = frame.qubits

region = stmt.callee.callable_region

# NOTE: need to set the block argument SSA values to the ones present in the frame
# FIXME: this feels wrong, there's probably a better way to do this
for block in region.blocks:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Roger-luo @weinbe58 do you know how I can properly get the BlockArguments and ReturnValues here? The SSA values are there, but the keys are different in the frame, so I resorted to some not so nice loops & checks.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now this will basically be the same as inlining the subroutine as you have it here.

I would either just error here and use inlining pass before running emit or think about how to lower a subroutine to a CircuitOperation

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we handle multiple blocks here? there is only one block? otherwise there will be some sort of control flows which I don't think Cirq supports anyways?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@weinbe58:

Right now this will basically be the same as inlining the subroutine as you have it here.

Kind of, but there's an important distinction: the result is wrapped as a CircuitOperation and added as a subcircuit.

or think about how to lower a subroutine to a CircuitOperation

Well, isn't this how you would do that? In the end, you need to emit another circuit from the subroutine somehow and then add it to the "parent" circuit. So, at some point we just need to step into the subroutine. That is sort of like inlining, but how else would you do it?

I don't see an alternative way to emit a CircuitOperation. We could probably clean up the way I'm passing in arguments and getting the return value, but I'm not sure how. Do you have an idea?

@Roger-luo:

why do we handle multiple blocks here?

Oversight on my part, I'm now throwing an error if there's more than a single block.

# NOTE: skip self in block args, so start at index 1
for block_arg, func_arg in zip(block.args[1:], args):
sub_frame.entries[block_arg] = frame.get(func_arg)

sub_circuit = emit.run_callable_region(
sub_frame, stmt.callee.code, region, ()
)
# emit.run_ssacfg_region(sub_frame, stmt.callee.callable_region, args=())

if not ret.type.is_subseteq(types.NoneType):
# NOTE: get the ResultValue of the return value and put it in the frame
# FIXME: this again feels _very_ wrong, there has to be a better way
ret_val = None
for val in sub_frame.entries.keys():
for use in val.uses:
if isinstance(use.stmt, func.Return):
ret_val = val
break

if ret_val is not None:
frame.entries[ret] = sub_frame.get(ret_val)

frame.circuit.append(
cirq.CircuitOperation(sub_circuit.freeze(), use_repetition_ids=False)
)
return ()


@op.dialect.register(key="emit.cirq")
class EmitCirqOpMethods(MethodTable):
@impl(op.stmts.X)
@impl(op.stmts.Y)
@impl(op.stmts.Z)
def pauli(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliOp
) -> tuple[cirq.Pauli]:
cirq_pauli = getattr(cirq, stmt.name.upper())
return (cirq_pauli,)

@impl(op.stmts.H)
def h(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.H):
return (cirq.H,)

@impl(op.stmts.S)
def s(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.S):
return (cirq.S,)

Check warning on line 121 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L121

Added line #L121 was not covered by tests

@impl(op.stmts.T)
def t(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.T):
return (cirq.T,)

Check warning on line 125 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L125

Added line #L125 was not covered by tests

@impl(op.stmts.P0)
def p0(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P0):
raise NotImplementedError("TODO")

Check warning on line 129 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L129

Added line #L129 was not covered by tests

@impl(op.stmts.P1)
def p1(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P1):
raise NotImplementedError("TODO")

Check warning on line 133 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L133

Added line #L133 was not covered by tests

@impl(op.stmts.Sn)
def sn(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sn):
raise NotImplementedError("TODO")

Check warning on line 137 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L137

Added line #L137 was not covered by tests

@impl(op.stmts.Sp)
def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):
raise NotImplementedError("TODO")

Check warning on line 141 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L141

Added line #L141 was not covered by tests

@impl(op.stmts.Identity)
def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
return (cirq.IdentityGate(num_qubits=stmt.sites),)

@impl(op.stmts.Control)
def control(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Control):
op: cirq.Gate = frame.get(stmt.op)
return (op.controlled(num_controls=stmt.n_controls),)

@impl(op.stmts.Kron)
def kron(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Kron):
# lhs = frame.get(stmt.lhs)
# rhs = frame.get(stmt.rhs)
raise NotImplementedError("TODO")

Check warning on line 156 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L156

Added line #L156 was not covered by tests

@impl(op.stmts.Mult)
def mult(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Mult):
# lhs = frame.get(stmt.lhs)
# rhs = frame.get(stmt.rhs)
raise NotImplementedError("TODO")

Check warning on line 162 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L162

Added line #L162 was not covered by tests

@impl(op.stmts.Adjoint)
def adjoint(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Adjoint):
raise NotImplementedError("TODO")

Check warning on line 166 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L166

Added line #L166 was not covered by tests

@impl(op.stmts.Scale)
def scale(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Scale):
raise NotImplementedError("TODO")

Check warning on line 170 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L170

Added line #L170 was not covered by tests

@impl(op.stmts.U3)
def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.U3):
raise NotImplementedError("TODO")

Check warning on line 174 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L174

Added line #L174 was not covered by tests

@impl(op.stmts.PhaseOp)
def phaseop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PhaseOp):
raise NotImplementedError("TODO")

Check warning on line 178 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L178

Added line #L178 was not covered by tests

@impl(op.stmts.ShiftOp)
def shiftop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ShiftOp):
raise NotImplementedError("TODO")

Check warning on line 182 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L182

Added line #L182 was not covered by tests

@impl(op.stmts.Reset)
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
raise NotImplementedError("TODO")

Check warning on line 186 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L186

Added line #L186 was not covered by tests

@impl(op.stmts.PauliString)
def pauli_string(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
):
raise NotImplementedError("TODO")

Check warning on line 192 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L192

Added line #L192 was not covered by tests


@qubit.dialect.register(key="emit.cirq")
class EmitCirqQubitMethods(MethodTable):
@impl(qubit.New)
def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New):
n_qubits = frame.get(stmt.n_qubits)

if frame.qubits is not None:
cirq_qubits = [frame.qubits[i + frame.qubit_index] for i in range(n_qubits)]
else:
cirq_qubits = [
cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits)
]

frame.qubit_index += n_qubits
return (cirq_qubits,)

@impl(qubit.Apply)
def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply):
op = frame.get(stmt.operator)
qbits = frame.get(stmt.qubits)
operation = op(*qbits)
frame.circuit.append(operation)
return ()

@impl(qubit.Broadcast)
def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast):
op = frame.get(stmt.operator)
qbits = frame.get(stmt.qubits)

cirq_ops = [op(qbit) for qbit in qbits]
frame.circuit.append(cirq.Moment(cirq_ops))
return ()

@impl(qubit.MeasureQubit)
def measure_qubit(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit
):
qbit = frame.get(stmt.qubit)
frame.circuit.append(cirq.measure(qbit))
return ()

Check warning on line 234 in src/bloqade/squin/cirq/emit_circuit.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/emit_circuit.py#L232-L234

Added lines #L232 - L234 were not covered by tests

@impl(qubit.MeasureQubitList)
def measure_qubit_list(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList
):
qbits = frame.get(stmt.qubits)
frame.circuit.append(cirq.measure(qbits))
return ()
Loading