-
Notifications
You must be signed in to change notification settings - Fork 1
Emit a cirq.Circuit from a squin kernel #311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
david-pl
wants to merge
36
commits into
main
Choose a base branch
from
david/squin-to-cirq-emit
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 20 commits
Commits
Show all changes
36 commits
Select commit
Hold shift + click to select a range
9093dd2
Emit cirq.Circuit from squin kernel
david-pl 4d63742
Better typing in emit
david-pl e124955
Impl all pauli ops
david-pl d67c2d0
Impl for controls
david-pl 4bf7647
Inline pass for shorthand squin wrappers
david-pl 5d9469b
Fix typing in heuristic
david-pl dfc076c
Allow passing in custom list of qubits
david-pl 3c7bb64
Emit func invoke as subcircuit
david-pl c16e99c
Fix setting block arguments
david-pl 7ff58a3
Fix qubit index
david-pl 9f04d7d
Test nested kernels
david-pl 624ad93
Better error message
david-pl c309ad1
Update qubit index in frame when emitting invoke
david-pl 179d2cf
Properly deal with return values
david-pl 41e799d
Don't add all values to a sub_frame
david-pl 59cf51a
Remove inline pass that is no longer needed
david-pl 894cd53
Add a test when returning qubits from a nested kernel
david-pl 80724b9
Implement missing qubit methods and remove return value from kernel w…
david-pl 0c75266
Error when trying to emit a method with return value
david-pl ba64cde
Add operator methods with TODO
david-pl 7373a00
Restructure files
david-pl f5348d5
Start implementing runtime
david-pl bf9d88a
Projector and Sp, Sn runtime
david-pl ca73dab
Fix type
david-pl 2cf45fc
Adjoint runtime
david-pl d7e04cb
U3 runtime
david-pl e955fd0
Scale runtime
david-pl 78f1a3d
Phase op runtime
david-pl 1970b79
Shift op and reset runtime
david-pl b2f0e47
PauliString runtime
david-pl b0e4972
Split out runtime
david-pl a7a4d4c
Add docstrings and some examples
david-pl f0aeae6
Merge branch 'main' into david/squin-to-cirq-emit
david-pl 7e56a29
Require single block in emit_invoke
david-pl f2ffed1
Fix flaky test
david-pl 96955e7
Merge branch 'main' into david/squin-to-cirq-emit
david-pl File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
from typing import Sequence | ||
from dataclasses import field, dataclass | ||
|
||
import cirq | ||
from kirin import ir, types | ||
from kirin.emit import EmitABC, EmitFrame | ||
from kirin.interp import MethodTable, impl | ||
from kirin.dialects import func | ||
from typing_extensions import Self | ||
|
||
from .. import op, qubit, kernel | ||
|
||
|
||
@dataclass | ||
class EmitCirqFrame(EmitFrame): | ||
qubit_index: int = 0 | ||
qubits: Sequence[cirq.Qid] | None = None | ||
circuit: cirq.Circuit = field(default_factory=cirq.Circuit) | ||
|
||
|
||
def _default_kernel(): | ||
return kernel | ||
|
||
|
||
@dataclass | ||
class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]): | ||
keys = ["emit.cirq", "main"] | ||
dialects: ir.DialectGroup = field(default_factory=_default_kernel) | ||
void = cirq.Circuit() | ||
qubits: Sequence[cirq.Qid] | None = None | ||
|
||
def initialize(self) -> Self: | ||
return super().initialize() | ||
|
||
def initialize_frame( | ||
self, code: ir.Statement, *, has_parent_access: bool = False | ||
) -> EmitCirqFrame: | ||
return EmitCirqFrame( | ||
code, has_parent_access=has_parent_access, qubits=self.qubits | ||
) | ||
|
||
def run_method(self, method: ir.Method, args: tuple[cirq.Circuit, ...]): | ||
return self.run_callable(method.code, args) | ||
|
||
def emit_block(self, frame: EmitCirqFrame, block: ir.Block) -> cirq.Circuit: | ||
for stmt in block.stmts: | ||
result = self.eval_stmt(frame, stmt) | ||
if isinstance(result, tuple): | ||
frame.set_values(stmt.results, result) | ||
|
||
return frame.circuit | ||
|
||
|
||
@func.dialect.register(key="emit.cirq") | ||
class FuncEmit(MethodTable): | ||
|
||
@impl(func.Function) | ||
def emit_func(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Function): | ||
emit.run_ssacfg_region(frame, stmt.body, ()) | ||
return (frame.circuit,) | ||
|
||
@impl(func.Invoke) | ||
def emit_invoke(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: func.Invoke): | ||
args = stmt.inputs | ||
ret = stmt.result | ||
|
||
with emit.new_frame(stmt.callee.code) as sub_frame: | ||
sub_frame.qubit_index = frame.qubit_index | ||
sub_frame.qubits = frame.qubits | ||
|
||
region = stmt.callee.callable_region | ||
|
||
# NOTE: need to set the block argument SSA values to the ones present in the frame | ||
# FIXME: this feels wrong, there's probably a better way to do this | ||
for block in region.blocks: | ||
# NOTE: skip self in block args, so start at index 1 | ||
for block_arg, func_arg in zip(block.args[1:], args): | ||
sub_frame.entries[block_arg] = frame.get(func_arg) | ||
|
||
sub_circuit = emit.run_callable_region( | ||
sub_frame, stmt.callee.code, region, () | ||
) | ||
# emit.run_ssacfg_region(sub_frame, stmt.callee.callable_region, args=()) | ||
|
||
if not ret.type.is_subseteq(types.NoneType): | ||
# NOTE: get the ResultValue of the return value and put it in the frame | ||
# FIXME: this again feels _very_ wrong, there has to be a better way | ||
ret_val = None | ||
for val in sub_frame.entries.keys(): | ||
for use in val.uses: | ||
if isinstance(use.stmt, func.Return): | ||
ret_val = val | ||
break | ||
|
||
if ret_val is not None: | ||
frame.entries[ret] = sub_frame.get(ret_val) | ||
|
||
frame.circuit.append( | ||
cirq.CircuitOperation(sub_circuit.freeze(), use_repetition_ids=False) | ||
) | ||
return () | ||
|
||
|
||
@op.dialect.register(key="emit.cirq") | ||
class EmitCirqOpMethods(MethodTable): | ||
@impl(op.stmts.X) | ||
@impl(op.stmts.Y) | ||
@impl(op.stmts.Z) | ||
def pauli( | ||
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliOp | ||
) -> tuple[cirq.Pauli]: | ||
cirq_pauli = getattr(cirq, stmt.name.upper()) | ||
return (cirq_pauli,) | ||
|
||
@impl(op.stmts.H) | ||
def h(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.H): | ||
return (cirq.H,) | ||
|
||
@impl(op.stmts.S) | ||
def s(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.S): | ||
return (cirq.S,) | ||
|
||
@impl(op.stmts.T) | ||
def t(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.T): | ||
return (cirq.T,) | ||
|
||
@impl(op.stmts.P0) | ||
def p0(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P0): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.P1) | ||
def p1(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.P1): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.Sn) | ||
def sn(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sn): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.Sp) | ||
def sp(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Sp): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.Identity) | ||
def identity(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Identity): | ||
return (cirq.IdentityGate(num_qubits=stmt.sites),) | ||
|
||
@impl(op.stmts.Control) | ||
def control(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Control): | ||
op: cirq.Gate = frame.get(stmt.op) | ||
return (op.controlled(num_controls=stmt.n_controls),) | ||
|
||
@impl(op.stmts.Kron) | ||
def kron(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Kron): | ||
# lhs = frame.get(stmt.lhs) | ||
# rhs = frame.get(stmt.rhs) | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.Mult) | ||
def mult(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Mult): | ||
# lhs = frame.get(stmt.lhs) | ||
# rhs = frame.get(stmt.rhs) | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.Adjoint) | ||
def adjoint(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Adjoint): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.Scale) | ||
def scale(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Scale): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.U3) | ||
def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.U3): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.PhaseOp) | ||
def phaseop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PhaseOp): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.ShiftOp) | ||
def shiftop(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.ShiftOp): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.Reset) | ||
def reset(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.Reset): | ||
raise NotImplementedError("TODO") | ||
|
||
@impl(op.stmts.PauliString) | ||
def pauli_string( | ||
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: op.stmts.PauliString | ||
): | ||
raise NotImplementedError("TODO") | ||
|
||
|
||
@qubit.dialect.register(key="emit.cirq") | ||
class EmitCirqQubitMethods(MethodTable): | ||
@impl(qubit.New) | ||
def new(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.New): | ||
n_qubits = frame.get(stmt.n_qubits) | ||
|
||
if frame.qubits is not None: | ||
cirq_qubits = [frame.qubits[i + frame.qubit_index] for i in range(n_qubits)] | ||
else: | ||
cirq_qubits = [ | ||
cirq.LineQubit(i + frame.qubit_index) for i in range(n_qubits) | ||
] | ||
|
||
frame.qubit_index += n_qubits | ||
return (cirq_qubits,) | ||
|
||
@impl(qubit.Apply) | ||
def apply(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Apply): | ||
op = frame.get(stmt.operator) | ||
qbits = frame.get(stmt.qubits) | ||
operation = op(*qbits) | ||
frame.circuit.append(operation) | ||
return () | ||
|
||
@impl(qubit.Broadcast) | ||
def broadcast(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.Broadcast): | ||
op = frame.get(stmt.operator) | ||
qbits = frame.get(stmt.qubits) | ||
|
||
cirq_ops = [op(qbit) for qbit in qbits] | ||
frame.circuit.append(cirq.Moment(cirq_ops)) | ||
return () | ||
|
||
@impl(qubit.MeasureQubit) | ||
def measure_qubit( | ||
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubit | ||
): | ||
qbit = frame.get(stmt.qubit) | ||
frame.circuit.append(cirq.measure(qbit)) | ||
return () | ||
|
||
@impl(qubit.MeasureQubitList) | ||
def measure_qubit_list( | ||
self, emit: EmitCirq, frame: EmitCirqFrame, stmt: qubit.MeasureQubitList | ||
): | ||
qbits = frame.get(stmt.qubits) | ||
frame.circuit.append(cirq.measure(qbits)) | ||
return () |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Roger-luo @weinbe58 do you know how I can properly get the
BlockArguments
andReturnValues
here? The SSA values are there, but the keys are different in the frame, so I resorted to some not so nice loops & checks.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right now this will basically be the same as inlining the subroutine as you have it here.
I would either just error here and use inlining pass before running emit or think about how to lower a subroutine to a CircuitOperation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why do we handle multiple blocks here? there is only one block? otherwise there will be some sort of control flows which I don't think Cirq supports anyways?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@weinbe58:
Kind of, but there's an important distinction: the result is wrapped as a
CircuitOperation
and added as a subcircuit.Well, isn't this how you would do that? In the end, you need to emit another circuit from the subroutine somehow and then add it to the "parent" circuit. So, at some point we just need to step into the subroutine. That is sort of like inlining, but how else would you do it?
I don't see an alternative way to emit a
CircuitOperation
. We could probably clean up the way I'm passing in arguments and getting the return value, but I'm not sure how. Do you have an idea?@Roger-luo:
Oversight on my part, I'm now throwing an error if there's more than a single block.