Skip to content

Commit a3e080c

Browse files
authored
add cluster_node_limit for MWPF decoder to better tune decoding time and accuracy (#857)
This is a new parameter that is agnostic to an individual machine's clock speed. We limit the maximum number of dual variables inside each cluster to avoid wasting computing time on small yet complicated clusters. A default value of 50 is good enough for small codes and improves the decoding speed a lot. Intuitively (but not precisely), this means we limit the maximum number of dual variables (blossoms and their children) in an alternating tree to 50; once it hits this limit, it falls back to the union-find decoder. Of course, in hypergraph cases, the situation is a little more complicated and this limit is hit more often.
1 parent e6fd563 commit a3e080c

File tree

2 files changed

+25
-14
lines changed

2 files changed

+25
-14
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -431,7 +431,7 @@ jobs:
431431
- run: bazel build :stim_dev_wheel
432432
- run: pip install bazel-bin/stim-0.0.dev0-py3-none-any.whl
433433
- run: pip install -e glue/sample
434-
- run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.1
434+
- run: pip install pytest pymatching fusion-blossom~=0.1.4 mwpf~=0.1.5
435435
- run: pytest glue/sample
436436
- run: dev/doctest_proper.py --module sinter
437437
- run: sinter help

glue/sample/src/sinter/_decoding/_decoding_mwpf.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError:
1515
return ImportError(
1616
"The decoder 'MWPF' isn't installed\n"
1717
"To fix this, install the python package 'MWPF' into your environment.\n"
18-
"For example, if you are using pip, run `pip install MWPF~=0.1.1`.\n"
18+
"For example, if you are using pip, run `pip install MWPF~=0.1.5`.\n"
1919
)
2020

2121

@@ -75,12 +75,18 @@ def compile_decoder_for_dem(
7575
# For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
7676
# grows the clusters until the first valid solution appears; some more optimized solvers use
7777
# one or more plugins to further optimize the solution, which requires longer decoding time.
78+
cluster_node_limit: int = 50, # The maximum number of nodes in a cluster.
7879
) -> CompiledDecoder:
7980
solver, fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks(
80-
dem, decoder_cls=decoder_cls
81+
dem,
82+
decoder_cls=decoder_cls,
83+
cluster_node_limit=cluster_node_limit,
8184
)
8285
return MwpfCompiledDecoder(
83-
solver, fault_masks, dem.num_detectors, dem.num_observables
86+
solver,
87+
fault_masks,
88+
dem.num_detectors,
89+
dem.num_observables,
8490
)
8591

8692
def decode_via_files(
@@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
220226
def deduplicate_hyperedges(
221227
hyperedges: List[Tuple[List[int], float, int]]
222228
) -> List[Tuple[List[int], float, int]]:
223-
indices: dict[frozenset[int], int] = dict()
229+
indices: dict[frozenset[int], Tuple[int, float]] = dict()
224230
result: List[Tuple[List[int], float, int]] = []
225231
for dets, weight, mask in hyperedges:
226232
dets_set = frozenset(dets)
227233
if dets_set in indices:
228-
idx = indices[dets_set]
234+
idx, min_weight = indices[dets_set]
229235
p1 = 1 / (1 + math.exp(weight))
230236
p2 = 1 / (1 + math.exp(result[idx][1]))
231237
p = p1 * (1 - p2) + p2 * (1 - p1)
232-
# not sure why would this fail? two hyperedges with different masks?
233-
# assert mask == result[idx][2], (result[idx], (dets, weight, mask))
234-
result[idx] = (dets, math.log((1 - p) / p), result[idx][2])
238+
# choosing the mask from the most likely error
239+
new_mask = result[idx][2]
240+
if weight < min_weight:
241+
indices[dets_set] = (idx, weight)
242+
new_mask = mask
243+
result[idx] = (dets, math.log((1 - p) / p), new_mask)
235244
else:
236-
indices[dets_set] = len(result)
245+
indices[dets_set] = (len(result), weight)
237246
result.append((dets, weight, mask))
238247
return result
239248

240249

241250
def detector_error_model_to_mwpf_solver_and_fault_masks(
242-
model: stim.DetectorErrorModel, decoder_cls: Any = None
251+
model: stim.DetectorErrorModel,
252+
decoder_cls: Any = None,
253+
cluster_node_limit: int = 50,
243254
) -> Tuple[Optional["mwpf.SolverSerialJointSingleHair"], np.ndarray]:
244255
"""Convert a stim error model into a NetworkX graph."""
245256

@@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]):
261272
# Accept it and keep going, though of course decoding will probably perform terribly.
262273
return
263274
if p > 0.5:
264-
# mwpf doesn't support negative edge weights.
275+
# mwpf doesn't support negative edge weights (yet, will be supported in the next version).
265276
# approximate them as weight 0.
266277
p = 0.5
267278
weight = math.log((1 - p) / p)
@@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
280291
# mwpf package panic on duplicate edges, thus we need to handle them here
281292
hyperedges = deduplicate_hyperedges(hyperedges)
282293

283-
# fix the input by connecting an edge to all isolated vertices
294+
# fix the input by connecting an edge to all isolated vertices; will be supported in the next version
284295
for idx in range(num_detectors):
285296
if not is_detector_connected[idx]:
286297
hyperedges.append(([idx], 0, 0))
@@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
301312
decoder_cls = mwpf.SolverSerialJointSingleHair
302313
return (
303314
(
304-
decoder_cls(initializer)
315+
decoder_cls(initializer, config={"cluster_node_limit": cluster_node_limit})
305316
if num_detectors > 0 and len(rescaled_edges) > 0
306317
else None
307318
),

0 commit comments

Comments
 (0)