Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a TreeScan Gate Removal pass to parallelize Scanning Gate #240

Merged
merged 14 commits into from
Aug 5, 2024
3 changes: 3 additions & 0 deletions bqskit/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
ExhaustiveGateRemovalPass
IterativeScanningGateRemovalPass
ScanningGateRemovalPass
TreeScanningGateRemovalPass
SubstitutePass

.. rubric:: Retargeting Passes
Expand Down Expand Up @@ -257,6 +258,7 @@
from bqskit.passes.processing.iterative import IterativeScanningGateRemovalPass
from bqskit.passes.processing.scan import ScanningGateRemovalPass
from bqskit.passes.processing.substitute import SubstitutePass
from bqskit.passes.processing.treescan import TreeScanningGateRemovalPass
from bqskit.passes.retarget.auto import AutoRebase2QuditGatePass
from bqskit.passes.retarget.general import GeneralSQDecomposition
from bqskit.passes.retarget.two import Rebase2QuditGatePass
Expand Down Expand Up @@ -329,6 +331,7 @@
'UpdateDataPass',
'ToU3Pass',
'ScanningGateRemovalPass',
'TreeScanningGateRemovalPass',
edyounis marked this conversation as resolved.
Show resolved Hide resolved
'SimpleLayerGenerator',
'AStarHeuristic',
'GreedyHeuristic',
Expand Down
2 changes: 2 additions & 0 deletions bqskit/passes/processing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
from bqskit.passes.processing.iterative import IterativeScanningGateRemovalPass
from bqskit.passes.processing.scan import ScanningGateRemovalPass
from bqskit.passes.processing.substitute import SubstitutePass
from bqskit.passes.processing.treescan import TreeScanningGateRemovalPass

__all__ = [
'ExhaustiveGateRemovalPass',
'IterativeScanningGateRemovalPass',
'ScanningGateRemovalPass',
'SubstitutePass',
'TreeScanningGateRemovalPass',
]
226 changes: 226 additions & 0 deletions bqskit/passes/processing/treescan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,226 @@
"""This module implements the ScanningGateRemovalPass."""
jkalloor3 marked this conversation as resolved.
Show resolved Hide resolved
from __future__ import annotations

import logging
from typing import Any
from typing import Callable

from bqskit.compiler.basepass import BasePass
from bqskit.compiler.passdata import PassData
from bqskit.ir.circuit import Circuit
from bqskit.ir.operation import Operation
from bqskit.ir.opt.cost.functions import HilbertSchmidtResidualsGenerator
from bqskit.ir.opt.cost.generator import CostFunctionGenerator
from bqskit.runtime import get_runtime
from bqskit.utils.typing import is_integer
from bqskit.utils.typing import is_real_number

_logger = logging.getLogger(__name__)


class TreeScanningGateRemovalPass(BasePass):
"""
The ScanningGateRemovalPass class.
edyounis marked this conversation as resolved.
Show resolved Hide resolved

Starting from one side of the circuit, run the following:

Split the circuit operations into chunks of size tree_depth
jkalloor3 marked this conversation as resolved.
Show resolved Hide resolved
At every iteration:
a. Look at the next chunk of operations
b. Generate 2 ^ tree_depth circuits. Each circuit corresponds to every
combination of whether or not to include one of the operations in the chunk.
c. Instantiate in parallel all 2^tree_depth circuits
d. Choose the circuit that has the least number of operations and move
on to the next chunk of operations.

This optimization is less greedy than the current ScanningGate removal,
edyounis marked this conversation as resolved.
Show resolved Hide resolved
which we see can offermuch better quality circuits than ScanningGate.
jkalloor3 marked this conversation as resolved.
Show resolved Hide resolved
In very rare occasions, ScanningGate may be able to outperform
edyounis marked this conversation as resolved.
Show resolved Hide resolved
TreeScan (since it is still greedy), but in general we can expect
TreeScan to almost always outperform ScanningGate.
"""

def __init__(
self,
start_from_left: bool = True,
success_threshold: float = 1e-8,
cost: CostFunctionGenerator = HilbertSchmidtResidualsGenerator(),
instantiate_options: dict[str, Any] = {},
tree_depth: int = 1,
collection_filter: Callable[[Operation], bool] | None = None,
) -> None:
"""
Construct a ScanningGateRemovalPass.
edyounis marked this conversation as resolved.
Show resolved Hide resolved

Args:
start_from_left (bool): Determines where the scan starts
attempting to remove gates from. If True, scan goes left
to right, otherwise right to left. (Default: True)

success_threshold (float): The distance threshold that
determines successful termintation. Measured in cost
described by the hilbert schmidt cost function.
(Default: 1e-8)

cost (CostFunction | None): The cost function that determines
successful removal of a gate.
(Default: HilbertSchmidtResidualsGenerator())

instantiate_options (dict[str: Any]): Options passed directly
to circuit.instantiate when instantiating circuit
templates. (Default: {})

tree_depth (int): The depth of the tree of potential
solutions to instantiate. Note that 2^(tree_depth) - 1
circuits will be instantiated in parallel. Note that the default
behavior will be equivalent to normal ScanningGateRemoval
(Default: 1)

collection_filter (Callable[[Operation], bool] | None):
A predicate that determines which operations should be
attempted to be removed. Called with each operation
in the circuit. If this returns true, this pass will
attempt to remove that operation. Defaults to all
operations.
"""

if not is_real_number(success_threshold):
raise TypeError(
'Expected real number for success_threshold'
', got %s' % type(success_threshold),
)

if not isinstance(cost, CostFunctionGenerator):
raise TypeError(
'Expected cost to be a CostFunctionGenerator, got %s'
% type(cost),
)

if not isinstance(instantiate_options, dict):
raise TypeError(
'Expected dictionary for instantiate_options, got %s.'
% type(instantiate_options),
)

self.collection_filter = collection_filter or default_collection_filter

if not callable(self.collection_filter):
raise TypeError(
'Expected callable method that maps Operations to booleans for'
' collection_filter, got %s.' % type(self.collection_filter),
)

if not is_integer(tree_depth):
raise TypeError(
'Expected Integer type for tree_depth, got %s.'
% type(instantiate_options),
)

self.tree_depth = tree_depth
jkalloor3 marked this conversation as resolved.
Show resolved Hide resolved
self.start_from_left = start_from_left
self.success_threshold = success_threshold
self.cost = cost
self.instantiate_options: dict[str, Any] = {
'dist_tol': self.success_threshold,
'min_iters': 10,
'cost_fn_gen': self.cost,
}
self.instantiate_options.update(instantiate_options)

@staticmethod
def get_tree_circs(
orig_num_cycles: int,
circuit_copy: Circuit,
cycle_and_ops: list[tuple[int, Operation]],
jkalloor3 marked this conversation as resolved.
Show resolved Hide resolved
) -> list[Circuit]:
'''
Given a circuit, create 2^(tree_depth) - 1 circuits that remove up
to tree_depth operations. The circuits are sorted by the number of
operations removed.
'''
edyounis marked this conversation as resolved.
Show resolved Hide resolved
all_circs = [circuit_copy.copy()]
for cycle, op in cycle_and_ops:
new_circs = []
for circ in all_circs:
idx_shift = orig_num_cycles - circ.num_cycles
new_cycle = cycle - idx_shift
work_copy = circ.copy()
work_copy.pop((new_cycle, op.location[0]))
new_circs.append(work_copy)
new_circs.append(circ)

all_circs = new_circs

all_circs = sorted(all_circs, key=lambda x: x.num_operations)
# Remove circuit with no gates deleted
return all_circs[:-1]

async def run(self, circuit: Circuit, data: PassData) -> None:
"""Perform the pass's operation, see :class:`BasePass` for more."""
instantiate_options = self.instantiate_options.copy()
if 'seed' not in instantiate_options:
instantiate_options['seed'] = data.seed

start = 'left' if self.start_from_left else 'right'
_logger.debug(f'Starting scanning gate removal on the {start}.')
edyounis marked this conversation as resolved.
Show resolved Hide resolved

target = self.get_target(circuit, data)
# target = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unnecessary line


circuit_copy = circuit.copy()
reverse_iter = not self.start_from_left

ops_left = list(circuit.operations_with_cycles(reverse=reverse_iter))
print(
f'Starting Scan with tree depth {self.tree_depth}'
' on circuit with {len(ops_left)} gates',
edyounis marked this conversation as resolved.
Show resolved Hide resolved
)

while ops_left:
chunk = ops_left[:self.tree_depth]
ops_left = ops_left[self.tree_depth:]

# Circuits of size 2 ** tree_depth - 1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"Circuits of size" is a bit misleading, it might be better to say "2 ** tree_depth - 1 Circuits".

Moreover, this comment would be redundant with a proper docstring for get_tree_circs

# ranked in order of most to fewest deletions
all_circs = TreeScanningGateRemovalPass.get_tree_circs(
circuit.num_cycles, circuit_copy, chunk,
)

_logger.debug(
'Attempting removal of operation of up to'
f' {self.tree_depth} operations.',
jkalloor3 marked this conversation as resolved.
Show resolved Hide resolved
)

instantiated_circuits: list[Circuit] = await get_runtime().map(
Circuit.instantiate,
all_circs,
target=target,
**instantiate_options,
)

dists = [self.cost(c, target) for c in instantiated_circuits]

# Pick least count with least dist
for i, dist in enumerate(dists):
if dist < self.success_threshold:
# Log gates removed
gate_dict_orig = circuit_copy.gate_counts
gate_dict_new = instantiated_circuits[i].gate_counts
gates_removed = {
k: circuit_copy.gate_counts[k] - gate_dict_new.get(k, 0)
for k in gate_dict_orig.keys()
}
gates_removed = {
k: v for k, v in gates_removed.items() if v != 0
}
_logger.debug(
f'Successfully removed {gates_removed} gates',
)
circuit_copy = instantiated_circuits[i]
break

circuit.become(circuit_copy)


def default_collection_filter(op: Operation) -> bool:
return True
Loading