Skip to content

Commit d52c597

Browse files
authored
Merge pull request #212 from peachnuts/add_reset
Add reset
2 parents 2d2778d + 2bf45ad commit d52c597

File tree

12 files changed

+84294
-37
lines changed

12 files changed

+84294
-37
lines changed

bqskit/ir/gates/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@
112112
113113
CircuitGate
114114
MeasurementPlaceholder
115+
Reset
115116
BarrierPlaceholder
116117
117118
.. rubric:: Gate Base Classes
@@ -146,11 +147,15 @@
146147
from bqskit.ir.gates.qubitgate import QubitGate
147148
from bqskit.ir.gates.quditgate import QuditGate
148149
from bqskit.ir.gates.qutritgate import QutritGate
150+
from bqskit.ir.gates.reset import Reset
149151

150152
__all__ = composed_all + constant_all + parameterized_all
151153
__all__ += ['ComposedGate', 'ConstantGate']
152154
__all__ += ['QubitGate', 'QutritGate', 'QuditGate']
153-
__all__ += ['CircuitGate', 'MeasurementPlaceholder', 'BarrierPlaceholder']
155+
__all__ += [
156+
'CircuitGate', 'MeasurementPlaceholder',
157+
'Reset', 'BarrierPlaceholder',
158+
]
154159
__all__ += ['GeneralGate']
155160

156161
# TODO: Implement the rest of the gates in:

bqskit/ir/gates/reset.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""This module implements the Reset class."""
2+
from __future__ import annotations
3+
4+
from bqskit.ir.gates.constantgate import ConstantGate
5+
from bqskit.qis.unitary.unitary import RealVector
6+
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
7+
8+
9+
class Reset(ConstantGate):
10+
"""Pseudogate to reset/initialize the qudit to |0>."""
11+
12+
def __init__(self, radix: int = 2) -> None:
13+
"""
14+
Construct a Reset.
15+
16+
Args:
17+
radix (int): the dimension of the qudit. (Default: 2)
18+
"""
19+
self._num_qudits = 1
20+
self._qasm_name = 'reset'
21+
self._radixes = tuple([radix])
22+
self._num_params = 0
23+
24+
def get_unitary(self, params: RealVector = []) -> UnitaryMatrix:
25+
raise RuntimeError('Cannot compute unitary for a reset.')

bqskit/ir/lang/qasm2/visitor.py

+30-29
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
from bqskit.ir.gates.parameterized.u1q import U1qGate
6969
from bqskit.ir.gates.parameterized.u2 import U2Gate
7070
from bqskit.ir.gates.parameterized.u3 import U3Gate
71+
from bqskit.ir.gates.reset import Reset
7172
from bqskit.ir.lang.language import LangException
7273
from bqskit.ir.lang.qasm2.parser import parse
7374
from bqskit.ir.location import CircuitLocation
@@ -169,7 +170,6 @@ def __init__(self) -> None:
169170
self.classical_regs: list[ClassicalReg] = []
170171
self.gate_def_parsing_obj: Any = None
171172
self.custom_gate_defs: dict[str, CustomGateDef] = {}
172-
self.measurements: dict[int, tuple[str, int]] = {}
173173
self.fill_gate_defs()
174174

175175
def get_circuit(self) -> Circuit:
@@ -180,12 +180,6 @@ def get_circuit(self) -> Circuit:
180180
circuit = Circuit(num_qubits)
181181
circuit.extend(self.op_list)
182182

183-
# Add measurements
184-
if len(self.measurements) > 0:
185-
cregs = cast(List[Tuple[str, int]], self.classical_regs)
186-
mph = MeasurementPlaceholder(cregs, self.measurements)
187-
circuit.append_gate(mph, list(self.measurements.keys()))
188-
189183
return circuit
190184

191185
def fill_gate_defs(self) -> None:
@@ -297,13 +291,6 @@ def gate(self, tree: lark.Tree) -> None:
297291
qlist = tree.children[-1]
298292
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))
299293

300-
if any(q in self.measurements for q in location):
301-
raise LangException(
302-
'BQSKit currently does not support mid-circuit measurements.'
303-
' Unable to apply a gate on the same qubit where a measurement'
304-
' has been previously made.',
305-
)
306-
307294
# Parse gate object
308295
gate_name = str(tree.children[0])
309296
if gate_name in self.gate_defs:
@@ -591,10 +578,16 @@ def creg(self, tree: lark.Tree) -> None:
591578

592579
def measure(self, tree: lark.Tree) -> None:
593580
"""Measure statement node visitor."""
581+
params: list[float] = []
582+
measurements: dict[int, tuple[str, int]] = {}
594583
qubit_childs = tree.children[0].children
595584
class_childs = tree.children[1].children
596585
qubit_reg_name = str(qubit_childs[0])
597586
class_reg_name = str(class_childs[0])
587+
cregs = cast(List[Tuple[str, int]], self.classical_regs)
588+
qlist = tree.children[0]
589+
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))
590+
598591
if not any(r.name == qubit_reg_name for r in self.qubit_regs):
599592
raise LangException(
600593
f'Measuring undefined qubit register: {qubit_reg_name}',
@@ -605,7 +598,7 @@ def measure(self, tree: lark.Tree) -> None:
605598
f'Measuring undefined classical register: {class_reg_name}',
606599
)
607600

608-
if len(qubit_childs) == 1 and len(class_childs) == 1:
601+
if len(qubit_childs) == 1 and len(class_childs) == 1: # for measure all
609602
for name, size in self.qubit_regs:
610603
if qubit_reg_name == name:
611604
qubit_size = size
@@ -625,34 +618,42 @@ def measure(self, tree: lark.Tree) -> None:
625618
if name == qubit_reg_name:
626619
break
627620
outer_idx += size
628-
629621
for i in range(qubit_size):
630-
self.measurements[outer_idx + i] = (class_reg_name, i)
622+
measurements[outer_idx + i] = (class_reg_name, i)
623+
mph = MeasurementPlaceholder(cregs, measurements)
631624

632625
elif len(qubit_childs) == 2 and len(class_childs) == 2:
626+
# measure qubits to clbits
633627
qubit_index = int(qubit_childs[1])
634628
class_index = int(class_childs[1])
635-
636-
# Convert qubit_index to global index
637-
outer_idx = 0
638-
for name, size in self.qubit_regs:
639-
if name == qubit_reg_name:
640-
qubit_index = outer_idx + qubit_index
641-
break
642-
outer_idx += size
643-
644-
self.measurements[qubit_index] = (class_reg_name, class_index)
629+
measurements[qubit_index] = (class_reg_name, class_index)
630+
mph = MeasurementPlaceholder(cregs, measurements)
645631

646632
else:
647633
raise LangException(
648634
'Invalid measurement: either a single qubit is being measured '
649635
'to a full classical register or a qubit register is being '
650636
'measured to a single classical bit.',
651637
)
638+
op = Operation(mph, location, params)
639+
self.op_list.append(op)
652640

653641
def reset(self, tree: lark.Tree) -> None:
654-
"""Reset statement node visitor."""
655-
raise LangException('BQSKit currently does not support resets.')
642+
"""Reset node visitor."""
643+
params: list[float] = []
644+
qlist = tree.children[-1]
645+
if len(qlist.children) == 2:
646+
location = CircuitLocation(self.convert_qubit_ids_to_indices(qlist))
647+
op = Operation(Reset(), location, params)
648+
self.op_list.append(op)
649+
else:
650+
locations = [
651+
CircuitLocation(i)
652+
for i in range(self.qubit_regs[0][1])
653+
]
654+
for location in locations:
655+
op = Operation(Reset(), location, params)
656+
self.op_list.append(op)
656657

657658
def convert_qubit_ids_to_indices(self, qlist: lark.Tree) -> list[int]:
658659
if qlist.data == 'anylist':

bqskit/passes/partitioning/quick.py

+23-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from bqskit.compiler.basepass import BasePass
99
from bqskit.compiler.passdata import PassData
1010
from bqskit.ir.circuit import Circuit
11+
from bqskit.ir.gates import MeasurementPlaceholder
12+
from bqskit.ir.gates import Reset
1113
from bqskit.ir.gates.barrier import BarrierPlaceholder
1214
from bqskit.ir.gates.circuitgate import CircuitGate
1315
from bqskit.ir.location import CircuitLocation
@@ -117,8 +119,15 @@ def process_pending_bins() -> None:
117119
merging = False
118120
for p in partitioned_circuit.rear:
119121
op = partitioned_circuit[p]
120-
if isinstance(op.gate, BarrierPlaceholder):
121-
# Don't merge through barriers
122+
if isinstance(
123+
op.gate, (
124+
BarrierPlaceholder,
125+
MeasurementPlaceholder,
126+
Reset,
127+
),
128+
):
129+
# Don't merge through barriers,
130+
# measurement, or reset
122131
continue
123132
qudits = list(op.location)
124133

@@ -179,7 +188,13 @@ def process_pending_bins() -> None:
179188
})
180189

181190
# Barriers close all overlapping bins
182-
if isinstance(op.gate, BarrierPlaceholder):
191+
if isinstance(
192+
op.gate, (
193+
BarrierPlaceholder,
194+
MeasurementPlaceholder,
195+
Reset,
196+
),
197+
):
183198
for bin in overlapping_bins:
184199
if close_bin_qudits(bin, location, cycle):
185200
num_closed += 1
@@ -349,7 +364,11 @@ def can_accommodate(self, loc: CircuitLocation, block_size: int) -> bool:
349364

350365

351366
class BarrierBin(Bin):
352-
"""A special bin made to mark and preserve barrier location."""
367+
"""
368+
A special bin made to mark and preserve barrier location.
369+
370+
For simplicity, the rest and measurement are treated as barrier as well.
371+
"""
353372

354373
def __init__(
355374
self,

bqskit/passes/partitioning/single.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
from bqskit.compiler.basepass import BasePass
55
from bqskit.compiler.passdata import PassData
66
from bqskit.ir.circuit import Circuit
7+
from bqskit.ir.gates import MeasurementPlaceholder
8+
from bqskit.ir.gates import Reset
79
from bqskit.ir.gates.barrier import BarrierPlaceholder
810
from bqskit.ir.region import CircuitRegion
911

@@ -31,7 +33,13 @@ async def run(self, circuit: Circuit, data: PassData) -> None:
3133
op = circuit[c, q]
3234
if (
3335
op.num_qudits == 1
34-
and not isinstance(op.gate, BarrierPlaceholder)
36+
and not isinstance(
37+
op.gate, (
38+
BarrierPlaceholder,
39+
MeasurementPlaceholder,
40+
Reset,
41+
),
42+
)
3543
):
3644
if region_start is None:
3745
region_start = c

tests/ir/lang/test_qasm_decode.py

+86
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from bqskit.ir.gates.parameterized.u1q import U1qGate
1616
from bqskit.ir.gates.parameterized.u2 import U2Gate
1717
from bqskit.ir.gates.parameterized.u3 import U3Gate
18+
from bqskit.ir.gates.reset import Reset
1819
from bqskit.ir.lang.language import LangException
1920
from bqskit.ir.lang.qasm2 import OPENQASM2Language
2021
from bqskit.qis.unitary.unitarymatrix import UnitaryMatrix
@@ -296,6 +297,43 @@ def test_include_simple(self) -> None:
296297
assert circuit.get_unitary().get_distance_from(gate_unitary) < 1e-7
297298

298299

300+
class TestReset:
301+
def test_reset_single_qubit(self) -> None:
302+
input = """
303+
OPENQASM 2.0;
304+
qreg q[1];
305+
reset q[0];
306+
"""
307+
circuit = OPENQASM2Language().decode(input)
308+
expected = Reset()
309+
assert circuit[0, 0].gate == expected
310+
311+
def test_reset_register(self) -> None:
312+
input = """
313+
OPENQASM 2.0;
314+
qreg q[2];
315+
reset q;
316+
"""
317+
circuit = OPENQASM2Language().decode(input)
318+
expected = Reset()
319+
assert circuit[0, 0].gate == expected
320+
assert circuit[0, 1].gate == expected
321+
322+
def test_mid_reset(self) -> None:
323+
input = """
324+
OPENQASM 2.0;
325+
qreg q[1];
326+
u1(0.1) q[0];
327+
reset q[0];
328+
u1(0.1) q[0];
329+
"""
330+
circuit = OPENQASM2Language().decode(input)
331+
reset = Reset()
332+
assert circuit[0, 0].gate == U1Gate()
333+
assert circuit[1, 0].gate == reset
334+
assert circuit[2, 0].gate == U1Gate()
335+
336+
299337
class TestMeasure:
300338
def test_measure_single_bit(self) -> None:
301339
input = """
@@ -308,6 +346,54 @@ def test_measure_single_bit(self) -> None:
308346
expected = MeasurementPlaceholder([('c', 1)], {0: ('c', 0)})
309347
assert circuit[0, 0].gate == expected
310348

349+
def test_mid_measure_single_bit(self) -> None:
350+
input = """
351+
OPENQASM 2.0;
352+
qreg q[1];
353+
creg c[1];
354+
u1(0.1) q[0];
355+
measure q[0] -> c[0];
356+
u1(0.1) q[0];
357+
"""
358+
circuit = OPENQASM2Language().decode(input)
359+
measurements = {0: ('c', 0)}
360+
measure = MeasurementPlaceholder([('c', 1)], measurements)
361+
assert circuit[0, 0].gate == U1Gate()
362+
assert circuit[1, 0].gate == measure
363+
assert circuit[2, 0].gate == U1Gate()
364+
365+
def test_mid_measure_register_1(self) -> None:
366+
input = """
367+
OPENQASM 2.0;
368+
qreg q[1];
369+
creg c[1];
370+
u1(0.1) q[0];
371+
measure q -> c;
372+
u1(0.1) q[0];
373+
"""
374+
circuit = OPENQASM2Language().decode(input)
375+
measurements = {0: ('c', 0)}
376+
measure = MeasurementPlaceholder([('c', 1)], measurements)
377+
assert circuit[0, 0].gate == U1Gate()
378+
assert circuit[1, 0].gate == measure
379+
assert circuit[2, 0].gate == U1Gate()
380+
381+
def test_mid_measure_register_2(self) -> None:
382+
input = """
383+
OPENQASM 2.0;
384+
qreg q[2];
385+
creg c[2];
386+
u1(0.1) q[0];
387+
measure q -> c;
388+
u1(0.1) q[0];
389+
"""
390+
circuit = OPENQASM2Language().decode(input)
391+
measurements = {0: ('c', 0), 1: ('c', 1)}
392+
measure = MeasurementPlaceholder([('c', 2)], measurements)
393+
assert circuit[0, 0].gate == U1Gate()
394+
assert circuit[1, 0].gate == measure
395+
assert circuit[2, 0].gate == U1Gate()
396+
311397
def test_measure_register_1(self) -> None:
312398
input = """
313399
OPENQASM 2.0;

tests/ir/lang/test_qasm_encode.py

+14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from bqskit.ir.circuit import Circuit
44
from bqskit.ir.gates import CNOTGate
5+
from bqskit.ir.gates import Reset
56
from bqskit.ir.gates import U3Gate
67
from bqskit.ir.lang.qasm2 import OPENQASM2Language
78

@@ -41,3 +42,16 @@ def test_nested_circuitgate(self) -> None:
4142
qasm = OPENQASM2Language().encode(circuit)
4243
parsed_circuit = OPENQASM2Language().decode(qasm)
4344
assert parsed_circuit.get_unitary().get_distance_from(in_utry) < 1e-7
45+
46+
def test_reset(self) -> None:
47+
circuit = Circuit(1)
48+
circuit.append_gate(Reset(), 0)
49+
50+
qasm = OPENQASM2Language().encode(circuit)
51+
expected = (
52+
'OPENQASM 2.0;\n'
53+
'include "qelib1.inc";\n'
54+
'qreg q[1];\n'
55+
'reset q[0];\n'
56+
)
57+
assert qasm == expected

0 commit comments

Comments
 (0)