Skip to content

Commit 02b8661

Browse files
committed
fix sinter post selection and add test
1 parent 365a1d8 commit 02b8661

File tree

2 files changed

+51
-2
lines changed

2 files changed

+51
-2
lines changed

glue/sample/src/sinter/_collection/_collection_test.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import collections
2+
from importlib.metadata import metadata
23
import multiprocessing
34
import pathlib
45
import tempfile
@@ -79,6 +80,54 @@ def test_collect():
7980
assert d[0.03].errors <= 70
8081
assert 1 <= d[0.04].errors <= 100
8182

83+
def test_collect_postselection():
84+
def postselect_all_detectors_predicate(index: int, metadata: any, coords: tuple) -> bool:
85+
return True
86+
87+
tasks = []
88+
for p in [0.01, 0.02, 0.03, 0.04]:
89+
circuit = stim.Circuit.generated(
90+
'repetition_code:memory',
91+
rounds=3,
92+
distance=3,
93+
after_clifford_depolarization=p,
94+
)
95+
mask = sinter._collection.post_selection_mask_from_predicate(
96+
circuit_or_dem=circuit,
97+
metadata={},
98+
postselected_detectors_predicate=postselect_all_detectors_predicate,
99+
)
100+
tasks.append(sinter.Task(
101+
circuit=circuit,
102+
decoder='pymatching',
103+
postselection_mask=mask,
104+
json_metadata={'p': p},
105+
collection_options=sinter.CollectionOptions(
106+
max_shots=1000,
107+
max_errors=100,
108+
start_batch_size=100,
109+
max_batch_size=1000,
110+
),
111+
))
112+
113+
results = sinter.collect(
114+
num_workers=2,
115+
tasks=tasks
116+
)
117+
probabilities = [e.json_metadata['p'] for e in results]
118+
assert len(probabilities) == len(set(probabilities))
119+
d = {e.json_metadata['p']: e for e in results}
120+
print(d)
121+
assert len(d) == 4
122+
for k, v in d.items():
123+
assert v.shots >= 1000
124+
assert v.errors <= 1 # there is some small probability for undetected logical error
125+
assert d[0.01].discards <= 200
126+
assert d[0.02].discards <= 300
127+
assert d[0.03].discards <= 500
128+
assert 100 <= d[0.04].discards <= 1000
129+
130+
82131

83132
def test_collect_from_paths():
84133
with tempfile.TemporaryDirectory() as d:

glue/sample/src/sinter/_decoding/_stim_then_decode_sampler.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,8 @@ def sample(self, max_shots: int) -> AnonTaskStats:
197197
raise ValueError("predictions.dtype != np.uint8")
198198
if len(predictions.shape) != 2:
199199
raise ValueError("len(predictions.shape) != 2")
200-
if predictions.shape[0] != num_shots:
201-
raise ValueError("predictions.shape[0] != num_shots")
200+
if predictions.shape[0] != num_shots - num_discards_1:
201+
raise ValueError("predictions.shape[0] != num_shots - num_discards_1")
202202
if predictions.shape[1] < actual_obs.shape[1]:
203203
raise ValueError("predictions.shape[1] < actual_obs.shape[1]")
204204
if predictions.shape[1] > actual_obs.shape[1] + 1:

0 commit comments

Comments
 (0)