@@ -15,7 +15,7 @@ def mwpf_import_error() -> ImportError:
1515 return ImportError (
1616 "The decoder 'MWPF' isn't installed\n "
1717 "To fix this, install the python package 'MWPF' into your environment.\n "
18- "For example, if you are using pip, run `pip install MWPF~=0.1.1 `.\n "
18+ "For example, if you are using pip, run `pip install MWPF~=0.1.5 `.\n "
1919 )
2020
2121
@@ -75,12 +75,18 @@ def compile_decoder_for_dem(
7575 # For example, `SolverSerialUnionFind` is the most basic solver without any plugin: it only
7676 # grows the clusters until the first valid solution appears; some more optimized solvers uses
7777 # one or more plugins to further optimize the solution, which requires longer decoding time.
78+ cluster_node_limit : int = 50 , # The maximum number of nodes in a cluster.
7879 ) -> CompiledDecoder :
7980 solver , fault_masks = detector_error_model_to_mwpf_solver_and_fault_masks (
80- dem , decoder_cls = decoder_cls
81+ dem ,
82+ decoder_cls = decoder_cls ,
83+ cluster_node_limit = cluster_node_limit ,
8184 )
8285 return MwpfCompiledDecoder (
83- solver , fault_masks , dem .num_detectors , dem .num_observables
86+ solver ,
87+ fault_masks ,
88+ dem .num_detectors ,
89+ dem .num_observables ,
8490 )
8591
8692 def decode_via_files (
@@ -220,26 +226,31 @@ def _helper(m: stim.DetectorErrorModel, reps: int):
220226def deduplicate_hyperedges (
221227 hyperedges : List [Tuple [List [int ], float , int ]]
222228) -> List [Tuple [List [int ], float , int ]]:
223- indices : dict [frozenset [int ], int ] = dict ()
229+ indices : dict [frozenset [int ], Tuple [ int , float ] ] = dict ()
224230 result : List [Tuple [List [int ], float , int ]] = []
225231 for dets , weight , mask in hyperedges :
226232 dets_set = frozenset (dets )
227233 if dets_set in indices :
228- idx = indices [dets_set ]
234+ idx , min_weight = indices [dets_set ]
229235 p1 = 1 / (1 + math .exp (weight ))
230236 p2 = 1 / (1 + math .exp (result [idx ][1 ]))
231237 p = p1 * (1 - p2 ) + p2 * (1 - p1 )
232- # not sure why would this fail? two hyperedges with different masks?
233- # assert mask == result[idx][2], (result[idx], (dets, weight, mask))
234- result [idx ] = (dets , math .log ((1 - p ) / p ), result [idx ][2 ])
238+ # choosing the mask from the most likely error
239+ new_mask = result [idx ][2 ]
240+ if weight < min_weight :
241+ indices [dets_set ] = (idx , weight )
242+ new_mask = mask
243+ result [idx ] = (dets , math .log ((1 - p ) / p ), new_mask )
235244 else :
236- indices [dets_set ] = len (result )
245+ indices [dets_set ] = ( len (result ), weight )
237246 result .append ((dets , weight , mask ))
238247 return result
239248
240249
241250def detector_error_model_to_mwpf_solver_and_fault_masks (
242- model : stim .DetectorErrorModel , decoder_cls : Any = None
251+ model : stim .DetectorErrorModel ,
252+ decoder_cls : Any = None ,
253+ cluster_node_limit : int = 50 ,
243254) -> Tuple [Optional ["mwpf.SolverSerialJointSingleHair" ], np .ndarray ]:
244255 """Convert a stim error model into a NetworkX graph."""
245256
@@ -261,7 +272,7 @@ def handle_error(p: float, dets: List[int], frame_changes: List[int]):
261272 # Accept it and keep going, though of course decoding will probably perform terribly.
262273 return
263274 if p > 0.5 :
264- # mwpf doesn't support negative edge weights.
275+ # mwpf doesn't support negative edge weights (yet, will be supported in the next version) .
265276 # approximate them as weight 0.
266277 p = 0.5
267278 weight = math .log ((1 - p ) / p )
@@ -280,7 +291,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
280291 # mwpf package panic on duplicate edges, thus we need to handle them here
281292 hyperedges = deduplicate_hyperedges (hyperedges )
282293
283- # fix the input by connecting an edge to all isolated vertices
294+ # fix the input by connecting an edge to all isolated vertices; will be supported in the next version
284295 for idx in range (num_detectors ):
285296 if not is_detector_connected [idx ]:
286297 hyperedges .append (([idx ], 0 , 0 ))
@@ -301,7 +312,7 @@ def handle_detector_coords(detector: int, coords: np.ndarray):
301312 decoder_cls = mwpf .SolverSerialJointSingleHair
302313 return (
303314 (
304- decoder_cls (initializer )
315+ decoder_cls (initializer , config = { "cluster_node_limit" : cluster_node_limit } )
305316 if num_detectors > 0 and len (rescaled_edges ) > 0
306317 else None
307318 ),
0 commit comments