Skip to content

Commit

Permalink
Register workflows by target type
Browse files Browse the repository at this point in the history
  • Loading branch information
mtweiden committed Sep 3, 2024
1 parent 45293ae commit 6bc2b13
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 28 deletions.
31 changes: 24 additions & 7 deletions bqskit/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
from bqskit.compiler.compiler import Compiler
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
Expand Down Expand Up @@ -669,12 +672,6 @@ def build_workflow(
if model is None:
model = MachineModel(input.num_qudits, radixes=input.radixes)

# Use a registered workflow if model is found in the registry for a given
# optimization_level
if model in _compile_registry:
if optimization_level in _compile_registry[model]:
return _compile_registry[model][optimization_level]

if isinstance(input, Circuit):
if input.num_qudits > max_synthesis_size:
if any(
Expand All @@ -691,6 +688,11 @@ def build_workflow(
'Unable to compile circuit with gate larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the circuit registry
# for a given optimization_level
if model in _compile_circuit_registry:
if optimization_level in _compile_circuit_registry[model]:
return _compile_circuit_registry[model][optimization_level]

return _circuit_workflow(
model,
Expand All @@ -708,6 +710,11 @@ def build_workflow(
'Unable to compile unitary with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the unitary registry
# for a given optimization_level
if model in _compile_unitary_registry:
if optimization_level in _compile_unitary_registry[model]:
return _compile_unitary_registry[model][optimization_level]

return _synthesis_workflow(
input,
Expand All @@ -726,6 +733,11 @@ def build_workflow(
'Unable to compile states with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the stateprep registry
# for a given optimization_level
if model in _compile_stateprep_registry:
if optimization_level in _compile_stateprep_registry[model]:
return _compile_stateprep_registry[model][optimization_level]

return _stateprep_workflow(
input,
Expand All @@ -744,6 +756,11 @@ def build_workflow(
'Unable to compile state systems with size larger than'
' max_synthesis_size.\nConsider adjusting it.',
)
# Use a registered workflow if model is found in the statemap registry
# for a given optimization_level
if model in _compile_statemap_registry:
if optimization_level in _compile_statemap_registry[model]:
return _compile_statemap_registry[model][optimization_level]

return _statemap_workflow(
input,
Expand Down
40 changes: 34 additions & 6 deletions bqskit/compiler/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,17 @@
from bqskit.compiler.workflow import WorkflowLike


_compile_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_circuit_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_unitary_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_stateprep_registry: dict[MachineModel, dict[int, Workflow]] = {}
_compile_statemap_registry: dict[MachineModel, dict[int, Workflow]] = {}


def register_workflow(
key: MachineModel,
workflow: WorkflowLike,
optimization_level: int,
target_type: str,
) -> None:
"""
Register a workflow for a given MachineModel.
Expand All @@ -34,10 +38,13 @@ def register_workflow(
be executed if the MachineModel in a call to `compile` matches
`key`. If `key` is already registered, a warning will be logged.
optimization_level ptional[int): The optimization level with which
optimization_level (Optional[int]): The optimization level with which
to register the workflow. If no level is provided, the Workflow
will be registered as level 1.
target_type (str): Register a workflow for targets of this type. Must
be 'circuit', 'unitary', 'stateprep', or 'statemap'.
Example:
model_t = SpecificMachineModel(num_qudits, radixes)
workflow = [QuickPartitioner(3), NewFangledOptimization()]
Expand All @@ -47,17 +54,38 @@ def register_workflow(
Raises:
Warning: If a workflow for a given optimization_level is overwritten.
ValueError: If `target_type` is not 'circuit', 'unitary', 'stateprep',
or 'statemap'.
"""
if target_type not in ['circuit', 'unitary', 'stateprep', 'statemap']:
m = 'target_type must be "circuit", "unitary", "stateprep", or '
m += f'"statemap", got {target_type}.'
raise ValueError(m)

if target_type == 'circuit':
global _compile_circuit_registry
_compile_registry = _compile_circuit_registry
elif target_type == 'unitary':
global _compile_unitary_registry
_compile_registry = _compile_unitary_registry
elif target_type == 'stateprep':
global _compile_stateprep_registry
_compile_registry = _compile_stateprep_registry
else:
global _compile_statemap_registry
_compile_registry = _compile_statemap_registry

workflow = Workflow(workflow)

global _compile_registry
new_workflow = {optimization_level: workflow}
if key in _compile_registry:
if optimization_level in _compile_registry[key]:
m = f'Overwritting workflow for {key} at level '
m += f'{optimization_level}. If multiple Namespace packages are '
m += 'installed, ensure that their __init__.py files do not '
m += 'attempt to overwrite the same default Workflows.'
m += f'{optimization_level} for target type {target_type}.'
m += 'If multiple Namespace packages are installed, ensure'
m += 'that their __init__.py files do not attempt to'
m += 'overwrite the same default Workflows.'
warnings.warn(m)
_compile_registry[key].update(new_workflow)
else:
Expand Down
63 changes: 48 additions & 15 deletions tests/compiler/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from bqskit.compiler.compile import compile
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import _compile_circuit_registry
from bqskit.compiler.registry import _compile_statemap_registry
from bqskit.compiler.registry import _compile_stateprep_registry
from bqskit.compiler.registry import _compile_unitary_registry
from bqskit.compiler.registry import register_workflow
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
Expand Down Expand Up @@ -74,28 +77,58 @@ class TestRegisterWorkflow:
@pytest.fixture(autouse=True)
def setup(self) -> None:
# global _compile_registry
_compile_registry.clear()
_compile_circuit_registry.clear()
_compile_unitary_registry.clear()
_compile_statemap_registry.clear()
_compile_stateprep_registry.clear()

def test_register_workflow(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
global _compile_unitary_registry
global _compile_statemap_registry
global _compile_stateprep_registry
assert _compile_circuit_registry == {}
assert _compile_unitary_registry == {}
assert _compile_statemap_registry == {}
assert _compile_stateprep_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
register_workflow(machine, workflow, 1)
assert machine in _compile_registry
assert 1 in _compile_registry[machine]
assert workflow_match(_compile_registry[machine][1], workflow)
circuit_workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
other_workflow = [QuickPartitioner(), QSearchSynthesisPass()]
register_workflow(machine, circuit_workflow, 1, 'circuit')
register_workflow(machine, other_workflow, 1, 'unitary')
register_workflow(machine, other_workflow, 1, 'statemap')
register_workflow(machine, other_workflow, 1, 'stateprep')
assert machine in _compile_circuit_registry
assert 1 in _compile_circuit_registry[machine]
assert workflow_match(
_compile_circuit_registry[machine][1], circuit_workflow,
)
assert machine in _compile_unitary_registry
assert 1 in _compile_unitary_registry[machine]
assert workflow_match(
_compile_unitary_registry[machine][1], other_workflow,
)
assert machine in _compile_statemap_registry
assert 1 in _compile_statemap_registry[machine]
assert workflow_match(
_compile_statemap_registry[machine][1], other_workflow,
)
assert machine in _compile_stateprep_registry
assert 1 in _compile_stateprep_registry[machine]
assert workflow_match(
_compile_stateprep_registry[machine][1], other_workflow,
)

def test_custom_compile_machine(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(machine, workflow, 1)
register_workflow(machine, workflow, 1, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
Expand All @@ -105,13 +138,13 @@ def test_custom_compile_machine(self) -> None:
assert result.gate_counts == circuit.gate_counts

def test_custom_opt_level(self) -> None:
global _compile_registry
assert _compile_registry == {}
global _compile_circuit_registry
assert _compile_circuit_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QSearchSynthesisPass()]
register_workflow(machine, workflow, 2)
register_workflow(machine, workflow, 2, 'circuit')
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine, optimization_level=2)
assert unitary_match(result, circuit)
Expand Down

0 comments on commit 6bc2b13

Please sign in to comment.