Skip to content

Commit 45293ae

Browse files
authored
Merge pull request #269 from BQSKit/register_workflow
Register `Workflow`s
2 parents 52fae4b + 1f94658 commit 45293ae

File tree

6 files changed

+213
-0
lines changed

6 files changed

+213
-0
lines changed

bqskit/compiler/compile.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from bqskit.compiler.compiler import Compiler
1515
from bqskit.compiler.machine import MachineModel
1616
from bqskit.compiler.passdata import PassData
17+
from bqskit.compiler.registry import _compile_registry
1718
from bqskit.compiler.workflow import Workflow
1819
from bqskit.compiler.workflow import WorkflowLike
1920
from bqskit.ir.circuit import Circuit
@@ -668,6 +669,12 @@ def build_workflow(
668669
if model is None:
669670
model = MachineModel(input.num_qudits, radixes=input.radixes)
670671

672+
# Use a registered workflow if model is found in the registry for a given
673+
# optimization_level
674+
if model in _compile_registry:
675+
if optimization_level in _compile_registry[model]:
676+
return _compile_registry[model][optimization_level]
677+
671678
if isinstance(input, Circuit):
672679
if input.num_qudits > max_synthesis_size:
673680
if any(

bqskit/compiler/gateset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,5 +231,9 @@ def __repr__(self) -> str:
231231
"""Detailed representation of the GateSet."""
232232
return self._gates.__repr__().replace('frozenset', 'GateSet')
233233

234+
def __hash__(self) -> int:
235+
"""Hash of the GateSet."""
236+
return hash(tuple(sorted([g.name for g in self._gates])))
237+
234238

235239
GateSetLike = Union[GateSet, Iterable[Gate], Gate]

bqskit/compiler/registry.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
"""Register MachineModel specific default workflows."""
2+
from __future__ import annotations
3+
4+
import warnings
5+
6+
from bqskit.compiler.machine import MachineModel
7+
from bqskit.compiler.workflow import Workflow
8+
from bqskit.compiler.workflow import WorkflowLike
9+
10+
11+
_compile_registry: dict[MachineModel, dict[int, Workflow]] = {}
12+
13+
14+
def register_workflow(
15+
key: MachineModel,
16+
workflow: WorkflowLike,
17+
optimization_level: int,
18+
) -> None:
19+
"""
20+
Register a workflow for a given MachineModel.
21+
22+
The _compile_registry enables MachineModel specific workflows to be
23+
registered for use in the `bqskit.compile` method. _compile_registry maps
24+
MachineModels a dictionary of Workflows which are indexed by optimization
25+
level. This object should not be accessed directly by the user, but
26+
instead through the `register_workflow` function.
27+
28+
Args:
29+
key (MachineModel): A MachineModel to register the workflow under.
30+
If a circuit is compiled targeting this machine or gate set, the
31+
registered workflow will be used.
32+
33+
workflow (list[BasePass]): The workflow or list of passes that will
34+
be executed if the MachineModel in a call to `compile` matches
35+
`key`. If `key` is already registered, a warning will be logged.
36+
37+
optimization_level ptional[int): The optimization level with which
38+
to register the workflow. If no level is provided, the Workflow
39+
will be registered as level 1.
40+
41+
Example:
42+
model_t = SpecificMachineModel(num_qudits, radixes)
43+
workflow = [QuickPartitioner(3), NewFangledOptimization()]
44+
register_workflow(model_t, workflow, level)
45+
...
46+
new_circuit = compile(circuit, model_t, optimization_level=level)
47+
48+
Raises:
49+
Warning: If a workflow for a given optimization_level is overwritten.
50+
"""
51+
workflow = Workflow(workflow)
52+
53+
global _compile_registry
54+
new_workflow = {optimization_level: workflow}
55+
if key in _compile_registry:
56+
if optimization_level in _compile_registry[key]:
57+
m = f'Overwritting workflow for {key} at level '
58+
m += f'{optimization_level}. If multiple Namespace packages are '
59+
m += 'installed, ensure that their __init__.py files do not '
60+
m += 'attempt to overwrite the same default Workflows.'
61+
warnings.warn(m)
62+
_compile_registry[key].update(new_workflow)
63+
else:
64+
_compile_registry[key] = new_workflow

bqskit/compiler/workflow.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ def name(self) -> str:
9090
"""The name of the pass."""
9191
return self._name or self.__class__.__name__
9292

93+
@staticmethod
94+
def is_workflow(workflow: WorkflowLike) -> bool:
95+
if not is_iterable(workflow):
96+
return isinstance(workflow, BasePass)
97+
return all(isinstance(p, BasePass) for p in workflow)
98+
9399
def __str__(self) -> str:
94100
name_seq = f'Workflow: {self.name}\n\t'
95101
pass_strs = [

tests/compiler/test_gateset.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -522,3 +522,16 @@ def test_gate_set_repr() -> None:
522522
repr(gate_set) == 'GateSet({CNOTGate, U3Gate})'
523523
or repr(gate_set) == 'GateSet({U3Gate, CNOTGate})'
524524
)
525+
526+
527+
def test_gate_set_hash() -> None:
528+
gate_set_1 = GateSet({CNOTGate(), U3Gate()})
529+
gate_set_2 = GateSet({U3Gate(), CNOTGate()})
530+
gate_set_3 = GateSet({U3Gate(), CNOTGate(), RZGate()})
531+
532+
h1 = hash(gate_set_1)
533+
h2 = hash(gate_set_2)
534+
h3 = hash(gate_set_3)
535+
536+
assert h1 == h2
537+
assert h1 != h3

tests/compiler/test_registry.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""This file tests the register_workflow function."""
2+
from __future__ import annotations
3+
4+
from itertools import combinations
5+
from random import choice
6+
7+
import pytest
8+
from numpy import allclose
9+
10+
from bqskit.compiler.compile import compile
11+
from bqskit.compiler.machine import MachineModel
12+
from bqskit.compiler.registry import _compile_registry
13+
from bqskit.compiler.registry import register_workflow
14+
from bqskit.compiler.workflow import Workflow
15+
from bqskit.compiler.workflow import WorkflowLike
16+
from bqskit.ir import Circuit
17+
from bqskit.ir import Gate
18+
from bqskit.ir.gates import CZGate
19+
from bqskit.ir.gates import HGate
20+
from bqskit.ir.gates import RZGate
21+
from bqskit.ir.gates import U3Gate
22+
from bqskit.passes import QSearchSynthesisPass
23+
from bqskit.passes import QuickPartitioner
24+
from bqskit.passes import ScanningGateRemovalPass
25+
26+
27+
def machine_match(mach_a: MachineModel, mach_b: MachineModel) -> bool:
28+
if mach_a.num_qudits != mach_b.num_qudits:
29+
return False
30+
if mach_a.radixes != mach_b.radixes:
31+
return False
32+
if mach_a.coupling_graph != mach_b.coupling_graph:
33+
return False
34+
if mach_a.gate_set != mach_b.gate_set:
35+
return False
36+
return True
37+
38+
39+
def unitary_match(unit_a: Circuit, unit_b: Circuit) -> bool:
40+
return allclose(unit_a.get_unitary(), unit_b.get_unitary(), atol=1e-5)
41+
42+
43+
def workflow_match(
44+
workflow_a: WorkflowLike,
45+
workflow_b: WorkflowLike,
46+
) -> bool:
47+
if not isinstance(workflow_a, Workflow):
48+
workflow_a = Workflow(workflow_a)
49+
if not isinstance(workflow_b, Workflow):
50+
workflow_b = Workflow(workflow_b)
51+
if len(workflow_a) != len(workflow_b):
52+
return False
53+
for a, b in zip(workflow_a, workflow_b):
54+
if a.name != b.name:
55+
return False
56+
return True
57+
58+
59+
def simple_circuit(num_qudits: int, gate_set: list[Gate]) -> Circuit:
60+
circ = Circuit(num_qudits)
61+
gate = choice(gate_set)
62+
if gate.num_qudits == 1:
63+
loc = choice(range(num_qudits))
64+
else:
65+
loc = choice(list(combinations(range(num_qudits), 2))) # type: ignore
66+
gate_inv = gate.get_inverse()
67+
circ.append_gate(gate, loc)
68+
circ.append_gate(gate_inv, loc)
69+
return circ
70+
71+
72+
class TestRegisterWorkflow:
73+
74+
@pytest.fixture(autouse=True)
75+
def setup(self) -> None:
76+
# global _compile_registry
77+
_compile_registry.clear()
78+
79+
def test_register_workflow(self) -> None:
80+
global _compile_registry
81+
assert _compile_registry == {}
82+
gateset = [CZGate(), HGate(), RZGate()]
83+
num_qudits = 3
84+
machine = MachineModel(num_qudits, gate_set=gateset)
85+
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
86+
register_workflow(machine, workflow, 1)
87+
assert machine in _compile_registry
88+
assert 1 in _compile_registry[machine]
89+
assert workflow_match(_compile_registry[machine][1], workflow)
90+
91+
def test_custom_compile_machine(self) -> None:
92+
global _compile_registry
93+
assert _compile_registry == {}
94+
gateset = [CZGate(), HGate(), RZGate()]
95+
num_qudits = 3
96+
machine = MachineModel(num_qudits, gate_set=gateset)
97+
workflow = [QuickPartitioner(2)]
98+
register_workflow(machine, workflow, 1)
99+
circuit = simple_circuit(num_qudits, gateset)
100+
result = compile(circuit, machine)
101+
assert unitary_match(result, circuit)
102+
assert result.num_operations > 0
103+
assert result.gate_counts != circuit.gate_counts
104+
result.unfold_all()
105+
assert result.gate_counts == circuit.gate_counts
106+
107+
def test_custom_opt_level(self) -> None:
108+
global _compile_registry
109+
assert _compile_registry == {}
110+
gateset = [CZGate(), HGate(), RZGate()]
111+
num_qudits = 3
112+
machine = MachineModel(num_qudits, gate_set=gateset)
113+
workflow = [QSearchSynthesisPass()]
114+
register_workflow(machine, workflow, 2)
115+
circuit = simple_circuit(num_qudits, gateset)
116+
result = compile(circuit, machine, optimization_level=2)
117+
assert unitary_match(result, circuit)
118+
assert result.gate_counts != circuit.gate_counts
119+
assert U3Gate() in result.gate_set

0 commit comments

Comments
 (0)