|
1 | 1 | import collections |
| 2 | +from importlib.metadata import metadata |
2 | 3 | import multiprocessing |
3 | 4 | import pathlib |
4 | 5 | import tempfile |
@@ -79,6 +80,54 @@ def test_collect(): |
79 | 80 | assert d[0.03].errors <= 70 |
80 | 81 | assert 1 <= d[0.04].errors <= 100 |
81 | 82 |
|
| 83 | +def test_collect_postselection(): |
| 84 | + def postselect_all_detectors_predicate(index: int, metadata: any, coords: tuple) -> bool: |
| 85 | + return True |
| 86 | + |
| 87 | + tasks = [] |
| 88 | + for p in [0.01, 0.02, 0.03, 0.04]: |
| 89 | + circuit = stim.Circuit.generated( |
| 90 | + 'repetition_code:memory', |
| 91 | + rounds=3, |
| 92 | + distance=3, |
| 93 | + after_clifford_depolarization=p, |
| 94 | + ) |
| 95 | + mask = sinter._collection.post_selection_mask_from_predicate( |
| 96 | + circuit_or_dem=circuit, |
| 97 | + metadata={}, |
| 98 | + postselected_detectors_predicate=postselect_all_detectors_predicate, |
| 99 | + ) |
| 100 | + tasks.append(sinter.Task( |
| 101 | + circuit=circuit, |
| 102 | + decoder='pymatching', |
| 103 | + postselection_mask=mask, |
| 104 | + json_metadata={'p': p}, |
| 105 | + collection_options=sinter.CollectionOptions( |
| 106 | + max_shots=1000, |
| 107 | + max_errors=100, |
| 108 | + start_batch_size=100, |
| 109 | + max_batch_size=1000, |
| 110 | + ), |
| 111 | + )) |
| 112 | + |
| 113 | + results = sinter.collect( |
| 114 | + num_workers=2, |
| 115 | + tasks=tasks |
| 116 | + ) |
| 117 | + probabilities = [e.json_metadata['p'] for e in results] |
| 118 | + assert len(probabilities) == len(set(probabilities)) |
| 119 | + d = {e.json_metadata['p']: e for e in results} |
| 120 | + print(d) |
| 121 | + assert len(d) == 4 |
| 122 | + for k, v in d.items(): |
| 123 | + assert v.shots >= 1000 |
| 124 | + assert v.errors <= 1 # there is some small probability for undetected logical error |
| 125 | + assert d[0.01].discards <= 200 |
| 126 | + assert d[0.02].discards <= 300 |
| 127 | + assert d[0.03].discards <= 500 |
| 128 | + assert 100 <= d[0.04].discards <= 1000 |
| 129 | + |
| 130 | + |
82 | 131 |
|
83 | 132 | def test_collect_from_paths(): |
84 | 133 | with tempfile.TemporaryDirectory() as d: |
|
0 commit comments