Skip to content

Emit a cirq.Circuit from a squin kernel #311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 36 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
9093dd2
Emit cirq.Circuit from squin kernel
david-pl Jun 3, 2025
4d63742
Better typing in emit
david-pl Jun 3, 2025
e124955
Impl all pauli ops
david-pl Jun 3, 2025
d67c2d0
Impl for controls
david-pl Jun 3, 2025
4bf7647
Inline pass for shorthand squin wrappers
david-pl Jun 4, 2025
5d9469b
Fix typing in heuristic
david-pl Jun 4, 2025
dfc076c
Allow passing in custom list of qubits
david-pl Jun 4, 2025
3c7bb64
Emit func invoke as subcircuit
david-pl Jun 4, 2025
c16e99c
Fix setting block arguments
david-pl Jun 4, 2025
7ff58a3
Fix qubit index
david-pl Jun 4, 2025
9f04d7d
Test nested kernels
david-pl Jun 4, 2025
624ad93
Better error message
david-pl Jun 4, 2025
c309ad1
Update qubit index in frame when emitting invoke
david-pl Jun 4, 2025
179d2cf
Properly deal with return values
david-pl Jun 4, 2025
41e799d
Don't add all values to a sub_frame
david-pl Jun 4, 2025
59cf51a
Remove inline pass that is no longer needed
david-pl Jun 4, 2025
894cd53
Add a test when returning qubits from a nested kernel
david-pl Jun 4, 2025
80724b9
Implement missing qubit methods and remove return value from kernel w…
david-pl Jun 4, 2025
0c75266
Error when trying to emit a method with return value
david-pl Jun 4, 2025
ba64cde
Add operator methods with TODO
david-pl Jun 4, 2025
7373a00
Restructure files
david-pl Jun 4, 2025
f5348d5
Start implementing runtime
david-pl Jun 4, 2025
bf9d88a
Projector and Sp, Sn runtime
david-pl Jun 5, 2025
ca73dab
Fix type
david-pl Jun 5, 2025
2cf45fc
Adjoint runtime
david-pl Jun 5, 2025
d7e04cb
U3 runtime
david-pl Jun 5, 2025
e955fd0
Scale runtime
david-pl Jun 5, 2025
78f1a3d
Phase op runtime
david-pl Jun 5, 2025
1970b79
Shift op and reset runtime
david-pl Jun 5, 2025
b2f0e47
PauliString runtime
david-pl Jun 5, 2025
b0e4972
Split out runtime
david-pl Jun 5, 2025
a7a4d4c
Add docstrings and some examples
david-pl Jun 5, 2025
f0aeae6
Merge branch 'main' into david/squin-to-cirq-emit
david-pl Jun 10, 2025
7e56a29
Require single block in emit_invoke
david-pl Jun 10, 2025
f2ffed1
Fix flaky test
david-pl Jun 10, 2025
96955e7
Merge branch 'main' into david/squin-to-cirq-emit
david-pl Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 112 additions & 1 deletion src/bloqade/squin/cirq/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing import Any
from typing import Any, Sequence

import cirq
from kirin import ir, types
from kirin.emit import EmitError
from kirin.dialects import func

from . import lowering as lowering
from .. import kernel

# NOTE: just to register methods
from .emit import op as op, qubit as qubit
from .lowering import Squin
from .emit.emit_circuit import EmitCirq


def load_circuit(
Expand Down Expand Up @@ -87,3 +92,109 @@
dialects=dialects,
code=code,
)


def emit_circuit(
mt: ir.Method,
qubits: Sequence[cirq.Qid] | None = None,
) -> cirq.Circuit:
"""Converts a squin.kernel method to a cirq.Circuit object.

Args:
mt (ir.Method): The kernel method from which to construct the circuit.

Keyword Args:
qubits (Sequence[cirq.Qid] | None):
A list of qubits to use as the qubits in the circuit. Defaults to None.
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new`
statement in the order they appear inside the kernel.
**Note**: If a list of qubits is provided, make sure that there is a sufficient
number of qubits for the resulting circuit.

## Examples:

Here's a very basic example:

```python
from bloqade import squin

@squin.kernel
def main():
q = squin.qubit.new(2)
h = squin.op.h()
squin.qubit.apply(h, q[0])
cx = squin.op.cx()
squin.qubit.apply(cx, q)

circuit = squin.cirq.emit_circuit(main)

print(circuit)
```

You can also compose multiple kernels. Those are emitted as subcircuits within the "main" circuit.
Subkernels can accept arguments and return a value.

```python
from bloqade import squin
from kirin.dialects import ilist
from typing import Literal
import cirq

@squin.kernel
def entangle(q: ilist.IList[squin.qubit.Qubit, Literal[2]]):
h = squin.op.h()
squin.qubit.apply(h, q[0])
cx = squin.op.cx()
squin.qubit.apply(cx, q)
return cx

@squin.kernel
def main():
q = squin.qubit.new(2)
cx = entangle(q)
q2 = squin.qubit.new(3)
squin.qubit.apply(cx, [q[1], q2[2]])


# custom list of qubits on grid
qubits = [cirq.GridQubit(i, i+1) for i in range(5)]

circuit = squin.cirq.emit_circuit(main, qubits=qubits)
print(circuit)

```

We also passed in a custom list of qubits above. This allows you to provide a custom geometry
and manipulate the qubits in other circuits directly written in cirq as well.
"""

if isinstance(mt.code, func.Function) and not mt.code.signature.output.is_subseteq(
types.NoneType
):
raise EmitError(
"The method you are trying to convert to a circuit has a return value, but returning from a circuit is not supported."
)

emitter = EmitCirq(qubits=qubits)
return emitter.run(mt, args=())


def dump_circuit(mt: ir.Method, qubits: Sequence[cirq.Qid] | None = None, **kwargs):
"""Converts a squin.kernel method to a cirq.Circuit object and dumps it as JSON.

This just runs `emit_circuit` and calls the `cirq.to_json` function to emit a JSON.

Args:
mt (ir.Method): The kernel method from which to construct the circuit.

Keyword Args:
qubits (Sequence[cirq.Qid] | None):
A list of qubits to use as the qubits in the circuit. Defaults to None.
If this is None, then `cirq.LineQubit`s are inserted for every `squin.qubit.new`
statement in the order they appear inside the kernel.
**Note**: If a list of qubits is provided, make sure that there is a sufficient
number of qubits for the resulting circuit.

"""
circuit = emit_circuit(mt, qubits=qubits)
return cirq.to_json(circuit, **kwargs)

Check warning on line 200 in src/bloqade/squin/cirq/__init__.py

View check run for this annotation

Codecov / codecov/patch

src/bloqade/squin/cirq/__init__.py#L199-L200

Added lines #L199 - L200 were not covered by tests
101 changes: 101 additions & 0 deletions src/bloqade/squin/cirq/emit/emit_circuit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from typing import Sequence
from dataclasses import field, dataclass

import cirq
from kirin import ir, types
from kirin.emit import EmitABC, EmitFrame
from kirin.interp import MethodTable, impl
from kirin.dialects import func
from typing_extensions import Self

from ... import kernel


@dataclass
class EmitCirqFrame(EmitFrame):
qubit_index: int = 0
qubits: Sequence[cirq.Qid] | None = None
circuit: cirq.Circuit = field(default_factory=cirq.Circuit)


def _default_kernel():
return kernel


@dataclass
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
keys = ["emit.cirq", "main"]
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
void = cirq.Circuit()
qubits: Sequence[cirq.Qid] | None = None

def initialize(self) -> Self:
return super().initialize()

def initialize_frame(
self, code: ir.Statement, *, has_parent_access: bool = False
) -> EmitCirqFrame:
return EmitCirqFrame(
code, has_parent_access=has_parent_access, qubits=self.qubits
)

def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]):
return self.run_callable(method.code, args)

def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit:
for stmt in block.stmts:
result = self.eval_stmt(frame, stmt)
if isinstance(result, tuple):
frame.set_values(stmt.results, result)

return frame.circuit


@func.dialect.register(key="emit.cirq")
class FuncEmit(MethodTable):

@impl(func.Function)
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function):
emit.run_ssacfg_region(frame, stmt.body, ())
return (frame.circuit,)

@impl(func.Invoke)
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke):
args = stmt.inputs
ret = stmt.result

with emit.new_frame(stmt.callee.code) as sub_frame:
sub_frame.qubit_index = frame.qubit_index
sub_frame.qubits = frame.qubits

region = stmt.callee.callable_region

# NOTE: need to set the block argument SSA values to the ones present in the frame
# FIXME: this feels wrong, there's probably a better way to do this
for block in region.blocks:
# NOTE: skip self in block args, so start at index 1
for block_arg, func_arg in zip(block.args[1:], args):
sub_frame.entries[block_arg] = frame.get(func_arg)

sub_circuit = emit.run_callable_region(
sub_frame, stmt.callee.code, region, ()
)
# emit.run_ssacfg_region(sub_frame, stmt.callee.callable_region, args=())

if not ret.type.is_subseteq(types.NoneType):
# NOTE: get the ResultValue of the return value and put it in the frame
# FIXME: this again feels _very_ wrong, there has to be a better way
ret_val = None
for val in sub_frame.entries.keys():
for use in val.uses:
if isinstance(use.stmt, func.Return):
ret_val = val
break

if ret_val is not None:
frame.entries[ret] = sub_frame.get(ret_val)

frame.circuit.append(
cirq.CircuitOperation(sub_circuit.freeze(), use_repetition_ids=False)
)
return ()
125 changes: 125 additions & 0 deletions src/bloqade/squin/cirq/emit/op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
import math

import cirq
import numpy as np
from kirin.interp import MethodTable, impl

from ... import op
from .runtime import (
SnRuntime,
SpRuntime,
U3Runtime,
KronRuntime,
MultRuntime,
ScaleRuntime,
AdjointRuntime,
ControlRuntime,
UnitaryRuntime,
HermitianRuntime,
ProjectorRuntime,
OperatorRuntimeABC,
PauliStringRuntime,
)
from .emit_circuit import EmitCirq, EmitCirqFrame


@op.dialect.register(key="emit.cirq")
class EmitCirqOpMethods(MethodTable):

@impl(op.stmts.X)
@impl(op.stmts.Y)
@impl(op.stmts.Z)
@impl(op.stmts.H)
def hermitian(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
):
cirq_op = getattr(cirq, stmt.name.upper())
return (HermitianRuntime(cirq_op),)

@impl(op.stmts.S)
@impl(op.stmts.T)
def unitary(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ConstantUnitary
):
cirq_op = getattr(cirq, stmt.name.upper())
return (UnitaryRuntime(cirq_op),)

@impl(op.stmts.P0)
@impl(op.stmts.P1)
def projector(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P0 | op.stmts.P1
):
return (ProjectorRuntime(isinstance(stmt, op.stmts.P1)),)

@impl(op.stmts.Sn)
def sn(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sn):
return (SnRuntime(),)

@impl(op.stmts.Sp)
def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp):
return (SpRuntime(),)

@impl(op.stmts.Identity)
def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity):
op = HermitianRuntime(cirq.IdentityGate(num_qubits=stmt.sites))
return (op,)

@impl(op.stmts.Control)
def control(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Control):
op: OperatorRuntimeABC = frame.get(stmt.op)
return (ControlRuntime(op, stmt.n_controls),)

@impl(op.stmts.Kron)
def kron(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Kron):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
op = KronRuntime(lhs, rhs)
return (op,)

@impl(op.stmts.Mult)
def mult(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Mult):
lhs = frame.get(stmt.lhs)
rhs = frame.get(stmt.rhs)
op = MultRuntime(lhs, rhs)
return (op,)

@impl(op.stmts.Adjoint)
def adjoint(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Adjoint):
op_ = frame.get(stmt.op)
return (AdjointRuntime(op_),)

@impl(op.stmts.Scale)
def scale(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Scale):
op_ = frame.get(stmt.op)
factor = frame.get(stmt.factor)
return (ScaleRuntime(operator=op_, factor=factor),)

@impl(op.stmts.U3)
def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.U3):
theta = frame.get(stmt.theta)
phi = frame.get(stmt.phi)
lam = frame.get(stmt.lam)
return (U3Runtime(theta=theta, phi=phi, lam=lam),)

@impl(op.stmts.PhaseOp)
def phaseop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PhaseOp):
theta = frame.get(stmt.theta)
op_ = HermitianRuntime(cirq.IdentityGate(num_qubits=1))
return (ScaleRuntime(operator=op_, factor=np.exp(1j * theta)),)

@impl(op.stmts.ShiftOp)
def shiftop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ShiftOp):
theta = frame.get(stmt.theta)

# NOTE: ShiftOp(theta) == U3(pi, theta, 0)
return (U3Runtime(math.pi, theta, 0),)

@impl(op.stmts.Reset)
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset):
return (HermitianRuntime(cirq.ResetChannel()),)

@impl(op.stmts.PauliString)
def pauli_string(
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString
):
return (PauliStringRuntime(stmt.string),)
Loading