-
Notifications
You must be signed in to change notification settings - Fork 38
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #269 from BQSKit/register_workflow
Register `Workflow`s
- Loading branch information
Showing
6 changed files
with
213 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
"""Register MachineModel specific default workflows.""" | ||
from __future__ import annotations | ||
|
||
import warnings | ||
|
||
from bqskit.compiler.machine import MachineModel | ||
from bqskit.compiler.workflow import Workflow | ||
from bqskit.compiler.workflow import WorkflowLike | ||
|
||
|
||
_compile_registry: dict[MachineModel, dict[int, Workflow]] = {} | ||
|
||
|
||
def register_workflow( | ||
key: MachineModel, | ||
workflow: WorkflowLike, | ||
optimization_level: int, | ||
) -> None: | ||
""" | ||
Register a workflow for a given MachineModel. | ||
The _compile_registry enables MachineModel specific workflows to be | ||
registered for use in the `bqskit.compile` method. _compile_registry maps | ||
MachineModels a dictionary of Workflows which are indexed by optimization | ||
level. This object should not be accessed directly by the user, but | ||
instead through the `register_workflow` function. | ||
Args: | ||
key (MachineModel): A MachineModel to register the workflow under. | ||
If a circuit is compiled targeting this machine or gate set, the | ||
registered workflow will be used. | ||
workflow (list[BasePass]): The workflow or list of passes that will | ||
be executed if the MachineModel in a call to `compile` matches | ||
`key`. If `key` is already registered, a warning will be logged. | ||
optimization_level ptional[int): The optimization level with which | ||
to register the workflow. If no level is provided, the Workflow | ||
will be registered as level 1. | ||
Example: | ||
model_t = SpecificMachineModel(num_qudits, radixes) | ||
workflow = [QuickPartitioner(3), NewFangledOptimization()] | ||
register_workflow(model_t, workflow, level) | ||
... | ||
new_circuit = compile(circuit, model_t, optimization_level=level) | ||
Raises: | ||
Warning: If a workflow for a given optimization_level is overwritten. | ||
""" | ||
workflow = Workflow(workflow) | ||
|
||
global _compile_registry | ||
new_workflow = {optimization_level: workflow} | ||
if key in _compile_registry: | ||
if optimization_level in _compile_registry[key]: | ||
m = f'Overwritting workflow for {key} at level ' | ||
m += f'{optimization_level}. If multiple Namespace packages are ' | ||
m += 'installed, ensure that their __init__.py files do not ' | ||
m += 'attempt to overwrite the same default Workflows.' | ||
warnings.warn(m) | ||
_compile_registry[key].update(new_workflow) | ||
else: | ||
_compile_registry[key] = new_workflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
"""This file tests the register_workflow function.""" | ||
from __future__ import annotations | ||
|
||
from itertools import combinations | ||
from random import choice | ||
|
||
import pytest | ||
from numpy import allclose | ||
|
||
from bqskit.compiler.compile import compile | ||
from bqskit.compiler.machine import MachineModel | ||
from bqskit.compiler.registry import _compile_registry | ||
from bqskit.compiler.registry import register_workflow | ||
from bqskit.compiler.workflow import Workflow | ||
from bqskit.compiler.workflow import WorkflowLike | ||
from bqskit.ir import Circuit | ||
from bqskit.ir import Gate | ||
from bqskit.ir.gates import CZGate | ||
from bqskit.ir.gates import HGate | ||
from bqskit.ir.gates import RZGate | ||
from bqskit.ir.gates import U3Gate | ||
from bqskit.passes import QSearchSynthesisPass | ||
from bqskit.passes import QuickPartitioner | ||
from bqskit.passes import ScanningGateRemovalPass | ||
|
||
|
||
def machine_match(mach_a: MachineModel, mach_b: MachineModel) -> bool: | ||
if mach_a.num_qudits != mach_b.num_qudits: | ||
return False | ||
if mach_a.radixes != mach_b.radixes: | ||
return False | ||
if mach_a.coupling_graph != mach_b.coupling_graph: | ||
return False | ||
if mach_a.gate_set != mach_b.gate_set: | ||
return False | ||
return True | ||
|
||
|
||
def unitary_match(unit_a: Circuit, unit_b: Circuit) -> bool: | ||
return allclose(unit_a.get_unitary(), unit_b.get_unitary(), atol=1e-5) | ||
|
||
|
||
def workflow_match( | ||
workflow_a: WorkflowLike, | ||
workflow_b: WorkflowLike, | ||
) -> bool: | ||
if not isinstance(workflow_a, Workflow): | ||
workflow_a = Workflow(workflow_a) | ||
if not isinstance(workflow_b, Workflow): | ||
workflow_b = Workflow(workflow_b) | ||
if len(workflow_a) != len(workflow_b): | ||
return False | ||
for a, b in zip(workflow_a, workflow_b): | ||
if a.name != b.name: | ||
return False | ||
return True | ||
|
||
|
||
def simple_circuit(num_qudits: int, gate_set: list[Gate]) -> Circuit: | ||
circ = Circuit(num_qudits) | ||
gate = choice(gate_set) | ||
if gate.num_qudits == 1: | ||
loc = choice(range(num_qudits)) | ||
else: | ||
loc = choice(list(combinations(range(num_qudits), 2))) # type: ignore | ||
gate_inv = gate.get_inverse() | ||
circ.append_gate(gate, loc) | ||
circ.append_gate(gate_inv, loc) | ||
return circ | ||
|
||
|
||
class TestRegisterWorkflow: | ||
|
||
@pytest.fixture(autouse=True) | ||
def setup(self) -> None: | ||
# global _compile_registry | ||
_compile_registry.clear() | ||
|
||
def test_register_workflow(self) -> None: | ||
global _compile_registry | ||
assert _compile_registry == {} | ||
gateset = [CZGate(), HGate(), RZGate()] | ||
num_qudits = 3 | ||
machine = MachineModel(num_qudits, gate_set=gateset) | ||
workflow = [QuickPartitioner(), ScanningGateRemovalPass()] | ||
register_workflow(machine, workflow, 1) | ||
assert machine in _compile_registry | ||
assert 1 in _compile_registry[machine] | ||
assert workflow_match(_compile_registry[machine][1], workflow) | ||
|
||
def test_custom_compile_machine(self) -> None: | ||
global _compile_registry | ||
assert _compile_registry == {} | ||
gateset = [CZGate(), HGate(), RZGate()] | ||
num_qudits = 3 | ||
machine = MachineModel(num_qudits, gate_set=gateset) | ||
workflow = [QuickPartitioner(2)] | ||
register_workflow(machine, workflow, 1) | ||
circuit = simple_circuit(num_qudits, gateset) | ||
result = compile(circuit, machine) | ||
assert unitary_match(result, circuit) | ||
assert result.num_operations > 0 | ||
assert result.gate_counts != circuit.gate_counts | ||
result.unfold_all() | ||
assert result.gate_counts == circuit.gate_counts | ||
|
||
def test_custom_opt_level(self) -> None: | ||
global _compile_registry | ||
assert _compile_registry == {} | ||
gateset = [CZGate(), HGate(), RZGate()] | ||
num_qudits = 3 | ||
machine = MachineModel(num_qudits, gate_set=gateset) | ||
workflow = [QSearchSynthesisPass()] | ||
register_workflow(machine, workflow, 2) | ||
circuit = simple_circuit(num_qudits, gateset) | ||
result = compile(circuit, machine, optimization_level=2) | ||
assert unitary_match(result, circuit) | ||
assert result.gate_counts != circuit.gate_counts | ||
assert U3Gate() in result.gate_set |