|
15 | 15 | """Support for serializing and deserializing cirq_google.api.v2 protos."""
|
16 | 16 |
|
17 | 17 | from typing import Any, Dict, List, Optional
|
| 18 | +import functools |
18 | 19 | import warnings
|
19 | 20 | import numpy as np
|
20 | 21 | import sympy
|
|
38 | 39 | # CircuitSerializer is the dedicated serializer for the v2.5 format.
|
39 | 40 | _SERIALIZER_NAME = 'v2_5'
|
40 | 41 |
|
| 42 | +# Package name for stimcirq |
| 43 | +_STIMCIRQ_MODULE = "stimcirq" |
| 44 | + |
41 | 45 |
|
42 | 46 | class CircuitSerializer(serializer.Serializer):
|
43 | 47 | """A class for serializing and deserializing programs and operations.
|
@@ -193,7 +197,6 @@ def _serialize_gate_op(
|
193 | 197 | ValueError: If the operation cannot be serialized.
|
194 | 198 | """
|
195 | 199 | gate = op.gate
|
196 |
| - |
197 | 200 | if isinstance(gate, InternalGate):
|
198 | 201 | arg_func_langs.internal_gate_arg_to_proto(gate, out=msg.internalgate)
|
199 | 202 | elif isinstance(gate, cirq.XPowGate):
|
@@ -260,6 +263,30 @@ def _serialize_gate_op(
|
260 | 263 | arg_func_langs.float_arg_to_proto(
|
261 | 264 | gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz
|
262 | 265 | )
|
| 266 | + elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr( |
| 267 | + gate, "__module__", "" |
| 268 | + ).startswith(_STIMCIRQ_MODULE): |
| 269 | + # Special handling for stimcirq objects, which can be both operations and gates. |
| 270 | + stimcirq_obj = ( |
| 271 | + op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate |
| 272 | + ) |
| 273 | + if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'): |
| 274 | + # All stimcirq gates currently have _json_dict_defined |
| 275 | + msg.internalgate.name = type(stimcirq_obj).__name__ |
| 276 | + msg.internalgate.module = _STIMCIRQ_MODULE |
| 277 | + if isinstance(stimcirq_obj, cirq.Gate): |
| 278 | + msg.internalgate.num_qubits = stimcirq_obj.num_qubits() |
| 279 | + else: |
| 280 | + msg.internalgate.num_qubits = len(stimcirq_obj.qubits) |
| 281 | + |
| 282 | + # Store json_dict objects in gate_args |
| 283 | + for k, v in stimcirq_obj._json_dict_().items(): |
| 284 | + arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k]) |
| 285 | + else: |
| 286 | + # New stimcirq op without a json dict has been introduced |
| 287 | + raise ValueError( |
| 288 | + f'Cannot serialize stimcirq {op!r}:{type(gate)}' |
| 289 | + ) # pragma: no cover |
263 | 290 | else:
|
264 | 291 | raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')
|
265 | 292 |
|
@@ -670,7 +697,21 @@ def _deserialize_gate_op(
|
670 | 697 | raise ValueError(f"dimensions {dimensions} for ResetChannel must be an integer!")
|
671 | 698 | op = cirq.ResetChannel(dimension=dimensions)(*qubits)
|
672 | 699 | elif which_gate_type == 'internalgate':
|
673 |
| - op = arg_func_langs.internal_gate_from_proto(operation_proto.internalgate)(*qubits) |
| 700 | + msg = operation_proto.internalgate |
| 701 | + if msg.module == _STIMCIRQ_MODULE and msg.name in _stimcirq_json_resolvers(): |
| 702 | + # special handling for stimcirq |
| 703 | + # Use JSON resolver to instantiate the object |
| 704 | + kwargs = {} |
| 705 | + for k, v in msg.gate_args.items(): |
| 706 | + arg = arg_func_langs.arg_from_proto(v) |
| 707 | + if arg is not None: |
| 708 | + kwargs[k] = arg |
| 709 | + op = _stimcirq_json_resolvers()[msg.name](**kwargs) |
| 710 | + if qubits: |
| 711 | + op = op(*qubits) |
| 712 | + else: |
| 713 | + # all other internal gates |
| 714 | + op = arg_func_langs.internal_gate_from_proto(msg)(*qubits) |
674 | 715 | elif which_gate_type == 'couplerpulsegate':
|
675 | 716 | gate = CouplerPulse(
|
676 | 717 | hold_time=cirq.Duration(
|
@@ -766,4 +807,16 @@ def _deserialize_tag(self, msg: v2.program_pb2.Tag):
|
766 | 807 | return None
|
767 | 808 |
|
768 | 809 |
|
| 810 | +@functools.cache |
| 811 | +def _stimcirq_json_resolvers(): |
| 812 | + """Retrieves stimcirq JSON resolvers if stimcirq is installed. |
| 813 | + Returns an empty dict if not installed.""" |
| 814 | + try: |
| 815 | + import stimcirq |
| 816 | + |
| 817 | + return stimcirq.JSON_RESOLVERS_DICT |
| 818 | + except ModuleNotFoundError: # pragma: no cover |
| 819 | + return {} # pragma: no cover |
| 820 | + |
| 821 | + |
769 | 822 | CIRCUIT_SERIALIZER = CircuitSerializer()
|
0 commit comments