Skip to content

Commit 2167552

Browse files
Support for stimcirq gates and operations in cirq_google protos (#7101)
* Support for stimcirq gates and operations in cirq_google protos - Adds special casing for stimcirq gates and operations. - Note that this only supports gates and operations where the arguments can be serialized. - Serializing the stimcirq gates uses the json dictionary in order to gather arguments from the operations. - Tests will only be run if stimcirq is installed (manual use only) * Fix some tests. * Add requirements for stimcirq * fix coverage * format * Address comments. * Fix coverage. * format * Update cirq-google/cirq_google/serialization/circuit_serializer_test.py Co-authored-by: Pavol Juhas <[email protected]> * Move import to cached function --------- Co-authored-by: Pavol Juhas <[email protected]>
1 parent 34b9c81 commit 2167552

File tree

5 files changed

+77
-4
lines changed

5 files changed

+77
-4
lines changed

cirq-core/cirq/ops/gate_operation_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -490,7 +490,9 @@ def all_subclasses(cls):
490490
gate_subclasses = {
491491
g
492492
for g in all_subclasses(cirq.Gate)
493-
if "cirq." in g.__module__ and "contrib" not in g.__module__ and "test" not in g.__module__
493+
if g.__module__.startswith("cirq.")
494+
and "contrib" not in g.__module__
495+
and "test" not in g.__module__
494496
}
495497

496498
test_module_spec = cirq.testing.json.spec_for("cirq.protocols")

cirq-google/cirq_google/serialization/circuit_serializer.py

+55-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Support for serializing and deserializing cirq_google.api.v2 protos."""
1616

1717
from typing import Any, Dict, List, Optional
18+
import functools
1819
import warnings
1920
import numpy as np
2021
import sympy
@@ -38,6 +39,9 @@
3839
# CircuitSerializer is the dedicated serializer for the v2.5 format.
3940
_SERIALIZER_NAME = 'v2_5'
4041

42+
# Package name for stimcirq
43+
_STIMCIRQ_MODULE = "stimcirq"
44+
4145

4246
class CircuitSerializer(serializer.Serializer):
4347
"""A class for serializing and deserializing programs and operations.
@@ -193,7 +197,6 @@ def _serialize_gate_op(
193197
ValueError: If the operation cannot be serialized.
194198
"""
195199
gate = op.gate
196-
197200
if isinstance(gate, InternalGate):
198201
arg_func_langs.internal_gate_arg_to_proto(gate, out=msg.internalgate)
199202
elif isinstance(gate, cirq.XPowGate):
@@ -260,6 +263,30 @@ def _serialize_gate_op(
260263
arg_func_langs.float_arg_to_proto(
261264
gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz
262265
)
266+
elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr(
267+
gate, "__module__", ""
268+
).startswith(_STIMCIRQ_MODULE):
269+
# Special handling for stimcirq objects, which can be both operations and gates.
270+
stimcirq_obj = (
271+
op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate
272+
)
273+
if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'):
274+
# All stimcirq gates currently have _json_dict_defined
275+
msg.internalgate.name = type(stimcirq_obj).__name__
276+
msg.internalgate.module = _STIMCIRQ_MODULE
277+
if isinstance(stimcirq_obj, cirq.Gate):
278+
msg.internalgate.num_qubits = stimcirq_obj.num_qubits()
279+
else:
280+
msg.internalgate.num_qubits = len(stimcirq_obj.qubits)
281+
282+
# Store json_dict objects in gate_args
283+
for k, v in stimcirq_obj._json_dict_().items():
284+
arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k])
285+
else:
286+
# New stimcirq op without a json dict has been introduced
287+
raise ValueError(
288+
f'Cannot serialize stimcirq {op!r}:{type(gate)}'
289+
) # pragma: no cover
263290
else:
264291
raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')
265292

@@ -670,7 +697,21 @@ def _deserialize_gate_op(
670697
raise ValueError(f"dimensions {dimensions} for ResetChannel must be an integer!")
671698
op = cirq.ResetChannel(dimension=dimensions)(*qubits)
672699
elif which_gate_type == 'internalgate':
673-
op = arg_func_langs.internal_gate_from_proto(operation_proto.internalgate)(*qubits)
700+
msg = operation_proto.internalgate
701+
if msg.module == _STIMCIRQ_MODULE and msg.name in _stimcirq_json_resolvers():
702+
# special handling for stimcirq
703+
# Use JSON resolver to instantiate the object
704+
kwargs = {}
705+
for k, v in msg.gate_args.items():
706+
arg = arg_func_langs.arg_from_proto(v)
707+
if arg is not None:
708+
kwargs[k] = arg
709+
op = _stimcirq_json_resolvers()[msg.name](**kwargs)
710+
if qubits:
711+
op = op(*qubits)
712+
else:
713+
# all other internal gates
714+
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
674715
elif which_gate_type == 'couplerpulsegate':
675716
gate = CouplerPulse(
676717
hold_time=cirq.Duration(
@@ -766,4 +807,16 @@ def _deserialize_tag(self, msg: v2.program_pb2.Tag):
766807
return None
767808

768809

810+
@functools.cache
811+
def _stimcirq_json_resolvers():
812+
"""Retrieves stimcirq JSON resolvers if stimcirq is installed.
813+
Returns an empty dict if not installed."""
814+
try:
815+
import stimcirq
816+
817+
return stimcirq.JSON_RESOLVERS_DICT
818+
except ModuleNotFoundError: # pragma: no cover
819+
return {} # pragma: no cover
820+
821+
769822
CIRCUIT_SERIALIZER = CircuitSerializer()

cirq-google/cirq_google/serialization/circuit_serializer_test.py

+15
Original file line numberDiff line numberDiff line change
@@ -1043,3 +1043,18 @@ def test_reset_gate_with_improper_argument():
10431043

10441044
with pytest.raises(ValueError, match="must be an integer"):
10451045
serializer.deserialize(circuit_proto)
1046+
1047+
1048+
def test_stimcirq_gates():
1049+
stimcirq = pytest.importorskip("stimcirq")
1050+
serializer = cg.CircuitSerializer()
1051+
q = cirq.q(1, 2)
1052+
q2 = cirq.q(2, 2)
1053+
c = cirq.Circuit(
1054+
cirq.Moment(stimcirq.CXSwapGate(inverted=True)(q, q2)),
1055+
cirq.Moment(cirq.measure(q, key="m")),
1056+
cirq.Moment(stimcirq.DetAnnotation(parity_keys=["m"])),
1057+
)
1058+
msg = serializer.serialize(c)
1059+
deserialized_circuit = serializer.deserialize(msg)
1060+
assert deserialized_circuit == c

dev_tools/conf/mypy.ini

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ follow_imports = silent
2121
ignore_missing_imports = true
2222

2323
# Non-Google
24-
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*]
24+
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*,stimcirq.*]
2525
follow_imports = silent
2626
ignore_missing_imports = true
2727

dev_tools/requirements/deps/dev-tools.txt

+3
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@ asv
1212

1313
# For verifying behavior of qasm output.
1414
qiskit-aer~=0.16.1
15+
16+
# For testing stimcirq compatibility (cirq-google)
17+
stimcirq

0 commit comments

Comments
 (0)