Skip to content

Commit

Permalink
Merge pull request #269 from BQSKit/register_workflow
Browse files Browse the repository at this point in the history
Register `Workflow`s
  • Loading branch information
mtweiden authored Aug 30, 2024
2 parents 52fae4b + 1f94658 commit 45293ae
Show file tree
Hide file tree
Showing 6 changed files with 213 additions and 0 deletions.
7 changes: 7 additions & 0 deletions bqskit/compiler/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from bqskit.compiler.compiler import Compiler
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.passdata import PassData
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir.circuit import Circuit
Expand Down Expand Up @@ -668,6 +669,12 @@ def build_workflow(
if model is None:
model = MachineModel(input.num_qudits, radixes=input.radixes)

# Use a registered workflow if model is found in the registry for a given
# optimization_level
if model in _compile_registry:
if optimization_level in _compile_registry[model]:
return _compile_registry[model][optimization_level]

if isinstance(input, Circuit):
if input.num_qudits > max_synthesis_size:
if any(
Expand Down
4 changes: 4 additions & 0 deletions bqskit/compiler/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,5 +231,9 @@ def __repr__(self) -> str:
"""Detailed representation of the GateSet."""
return self._gates.__repr__().replace('frozenset', 'GateSet')

def __hash__(self) -> int:
"""Hash of the GateSet."""
return hash(tuple(sorted([g.name for g in self._gates])))


GateSetLike = Union[GateSet, Iterable[Gate], Gate]
64 changes: 64 additions & 0 deletions bqskit/compiler/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Register MachineModel specific default workflows."""
from __future__ import annotations

import warnings

from bqskit.compiler.machine import MachineModel
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike


_compile_registry: dict[MachineModel, dict[int, Workflow]] = {}


def register_workflow(
key: MachineModel,
workflow: WorkflowLike,
optimization_level: int,
) -> None:
"""
Register a workflow for a given MachineModel.
The _compile_registry enables MachineModel specific workflows to be
registered for use in the `bqskit.compile` method. _compile_registry maps
MachineModels a dictionary of Workflows which are indexed by optimization
level. This object should not be accessed directly by the user, but
instead through the `register_workflow` function.
Args:
key (MachineModel): A MachineModel to register the workflow under.
If a circuit is compiled targeting this machine or gate set, the
registered workflow will be used.
workflow (list[BasePass]): The workflow or list of passes that will
be executed if the MachineModel in a call to `compile` matches
`key`. If `key` is already registered, a warning will be logged.
optimization_level ptional[int): The optimization level with which
to register the workflow. If no level is provided, the Workflow
will be registered as level 1.
Example:
model_t = SpecificMachineModel(num_qudits, radixes)
workflow = [QuickPartitioner(3), NewFangledOptimization()]
register_workflow(model_t, workflow, level)
...
new_circuit = compile(circuit, model_t, optimization_level=level)
Raises:
Warning: If a workflow for a given optimization_level is overwritten.
"""
workflow = Workflow(workflow)

global _compile_registry
new_workflow = {optimization_level: workflow}
if key in _compile_registry:
if optimization_level in _compile_registry[key]:
m = f'Overwritting workflow for {key} at level '
m += f'{optimization_level}. If multiple Namespace packages are '
m += 'installed, ensure that their __init__.py files do not '
m += 'attempt to overwrite the same default Workflows.'
warnings.warn(m)
_compile_registry[key].update(new_workflow)
else:
_compile_registry[key] = new_workflow
6 changes: 6 additions & 0 deletions bqskit/compiler/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ def name(self) -> str:
"""The name of the pass."""
return self._name or self.__class__.__name__

@staticmethod
def is_workflow(workflow: WorkflowLike) -> bool:
if not is_iterable(workflow):
return isinstance(workflow, BasePass)
return all(isinstance(p, BasePass) for p in workflow)

def __str__(self) -> str:
name_seq = f'Workflow: {self.name}\n\t'
pass_strs = [
Expand Down
13 changes: 13 additions & 0 deletions tests/compiler/test_gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,3 +522,16 @@ def test_gate_set_repr() -> None:
repr(gate_set) == 'GateSet({CNOTGate, U3Gate})'
or repr(gate_set) == 'GateSet({U3Gate, CNOTGate})'
)


def test_gate_set_hash() -> None:
gate_set_1 = GateSet({CNOTGate(), U3Gate()})
gate_set_2 = GateSet({U3Gate(), CNOTGate()})
gate_set_3 = GateSet({U3Gate(), CNOTGate(), RZGate()})

h1 = hash(gate_set_1)
h2 = hash(gate_set_2)
h3 = hash(gate_set_3)

assert h1 == h2
assert h1 != h3
119 changes: 119 additions & 0 deletions tests/compiler/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""This file tests the register_workflow function."""
from __future__ import annotations

from itertools import combinations
from random import choice

import pytest
from numpy import allclose

from bqskit.compiler.compile import compile
from bqskit.compiler.machine import MachineModel
from bqskit.compiler.registry import _compile_registry
from bqskit.compiler.registry import register_workflow
from bqskit.compiler.workflow import Workflow
from bqskit.compiler.workflow import WorkflowLike
from bqskit.ir import Circuit
from bqskit.ir import Gate
from bqskit.ir.gates import CZGate
from bqskit.ir.gates import HGate
from bqskit.ir.gates import RZGate
from bqskit.ir.gates import U3Gate
from bqskit.passes import QSearchSynthesisPass
from bqskit.passes import QuickPartitioner
from bqskit.passes import ScanningGateRemovalPass


def machine_match(mach_a: MachineModel, mach_b: MachineModel) -> bool:
if mach_a.num_qudits != mach_b.num_qudits:
return False
if mach_a.radixes != mach_b.radixes:
return False
if mach_a.coupling_graph != mach_b.coupling_graph:
return False
if mach_a.gate_set != mach_b.gate_set:
return False
return True


def unitary_match(unit_a: Circuit, unit_b: Circuit) -> bool:
return allclose(unit_a.get_unitary(), unit_b.get_unitary(), atol=1e-5)


def workflow_match(
workflow_a: WorkflowLike,
workflow_b: WorkflowLike,
) -> bool:
if not isinstance(workflow_a, Workflow):
workflow_a = Workflow(workflow_a)
if not isinstance(workflow_b, Workflow):
workflow_b = Workflow(workflow_b)
if len(workflow_a) != len(workflow_b):
return False
for a, b in zip(workflow_a, workflow_b):
if a.name != b.name:
return False
return True


def simple_circuit(num_qudits: int, gate_set: list[Gate]) -> Circuit:
circ = Circuit(num_qudits)
gate = choice(gate_set)
if gate.num_qudits == 1:
loc = choice(range(num_qudits))
else:
loc = choice(list(combinations(range(num_qudits), 2))) # type: ignore
gate_inv = gate.get_inverse()
circ.append_gate(gate, loc)
circ.append_gate(gate_inv, loc)
return circ


class TestRegisterWorkflow:

@pytest.fixture(autouse=True)
def setup(self) -> None:
# global _compile_registry
_compile_registry.clear()

def test_register_workflow(self) -> None:
global _compile_registry
assert _compile_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(), ScanningGateRemovalPass()]
register_workflow(machine, workflow, 1)
assert machine in _compile_registry
assert 1 in _compile_registry[machine]
assert workflow_match(_compile_registry[machine][1], workflow)

def test_custom_compile_machine(self) -> None:
global _compile_registry
assert _compile_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QuickPartitioner(2)]
register_workflow(machine, workflow, 1)
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine)
assert unitary_match(result, circuit)
assert result.num_operations > 0
assert result.gate_counts != circuit.gate_counts
result.unfold_all()
assert result.gate_counts == circuit.gate_counts

def test_custom_opt_level(self) -> None:
global _compile_registry
assert _compile_registry == {}
gateset = [CZGate(), HGate(), RZGate()]
num_qudits = 3
machine = MachineModel(num_qudits, gate_set=gateset)
workflow = [QSearchSynthesisPass()]
register_workflow(machine, workflow, 2)
circuit = simple_circuit(num_qudits, gateset)
result = compile(circuit, machine, optimization_level=2)
assert unitary_match(result, circuit)
assert result.gate_counts != circuit.gate_counts
assert U3Gate() in result.gate_set

0 comments on commit 45293ae

Please sign in to comment.