Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions glue/sample/src/sinter/_decoding/_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ def sample_decode(*,
decoder: The name of the decoder to use. Allowed values are:
"pymatching":
Use pymatching min-weight-perfect-match decoder.
"correlated_pymatching":
Use two-pass correlated pymatching decoder.
"fusion_blossom":
Use fusion blossom min-weight-perfect-match decoder.
"hypergraph_union_find":
Use weighted hypergraph union-find decoder.
"mw_parity_factor":
Use mwpf min-weight-parity-factor decoder.
"internal":
Use internal decoder with uncorrelated decoding.
"internal_correlated":
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict
from typing import Union

from sinter._decoding._decoding_correlated_pymatching import CorrelatedPyMatchingDecoder
from sinter._decoding._decoding_decoder_class import Decoder
from sinter._decoding._decoding_fusion_blossom import FusionBlossomDecoder
from sinter._decoding._decoding_pymatching import PyMatchingDecoder
Expand All @@ -12,6 +13,7 @@
BUILT_IN_DECODERS: Dict[str, Decoder] = {
'vacuous': VacuousDecoder(),
'pymatching': PyMatchingDecoder(),
"correlated_pymatching": CorrelatedPyMatchingDecoder(),
'fusion_blossom': FusionBlossomDecoder(),
# an implementation of (weighted) hypergraph UF decoder (https://arxiv.org/abs/2103.08049)
'hypergraph_union_find': HyperUFDecoder(),
Expand Down
106 changes: 106 additions & 0 deletions glue/sample/src/sinter/_decoding/_decoding_correlated_pymatching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from packaging import version

from sinter._decoding._decoding_decoder_class import Decoder, CompiledDecoder


class CorrelatedPyMatchingCompiledDecoder(CompiledDecoder):
def __init__(self, matcher: "pymatching.Matching"):
self.matcher = matcher

def decode_shots_bit_packed(
self,
*,
bit_packed_detection_event_data: "np.ndarray",
) -> "np.ndarray":
return self.matcher.decode_batch(
shots=bit_packed_detection_event_data,
bit_packed_shots=True,
bit_packed_predictions=True,
return_weights=False,
)


class CorrelatedPyMatchingDecoder(Decoder):
"""Use correlated pymatching to predict observables from detection events."""

def compile_decoder_for_dem(
self, *, dem: "stim.DetectorErrorModel"
) -> CompiledDecoder:
try:
import pymatching
except ImportError as ex:
raise ImportError(
"The decoder 'pymatching' isn't installed\n"
"To fix this, install the python package 'pymatching' into your environment.\n"
"For example, if you are using pip, run `pip install pymatching`.\n"
) from ex

# correlated matching requires pymatching 2.3.1 or later
if version.parse(pymatching.__version__) < version.parse("2.3.1"):
raise ValueError("""
The correlated pymatching decoder requires pymatching 2.3.1 or later.

If you're using pip to install packages, this can be fixed by running
```
pip install "pymatching~=2.3.1" --upgrade
```
""")

return CorrelatedPyMatchingCompiledDecoder(
pymatching.Matching.from_detector_error_model(dem, enable_correlations=True)
)

def decode_via_files(
self,
*,
num_shots: int,
num_dets: int,
num_obs: int,
dem_path: "pathlib.Path",
dets_b8_in_path: "pathlib.Path",
obs_predictions_b8_out_path: "pathlib.Path",
tmp_dir: "pathlib.Path",
) -> None:
try:
import pymatching
except ImportError as ex:
raise ImportError(
"The decoder 'pymatching' isn't installed\n"
"To fix this, install the python package 'pymatching' into your environment.\n"
"For example, if you are using pip, run `pip install pymatching`.\n"
) from ex

# correlated matching requires pymatching 2.3.1 or later
if version.parse(pymatching.__version__) < version.parse("2.3.1"):
raise ValueError("""
The correlated pymatching decoder requires pymatching 2.3.1 or later.

If you're using pip to install packages, this can be fixed by running
```
pip install "pymatching~=2.3.1" --upgrade
```
""")

if num_dets == 0:
with open(obs_predictions_b8_out_path, "wb") as f:
f.write(b"\0" * (num_obs * num_shots))
return

result = pymatching.cli(
command_line_args=[
"predict",
"--dem",
str(dem_path),
"--in",
str(dets_b8_in_path),
"--in_format",
"b8",
"--out",
str(obs_predictions_b8_out_path),
"--out_format",
"b8",
"--enable_correlations",
]
)
if result:
raise ValueError("pymatching.cli returned a non-zero exit code")
13 changes: 9 additions & 4 deletions glue/sample/src/sinter/_decoding/_decoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def get_test_decoders() -> Tuple[List[str], Dict[str, sinter.Decoder]]:
custom_decoders = {}
try:
import pymatching
from packaging import version

if version.parse(pymatching.__version__) < version.parse("2.3.0"):
available_decoders.remove('correlated_pymatching')

except ImportError:
available_decoders.remove('pymatching')
try:
Expand Down Expand Up @@ -234,7 +239,7 @@ def test_no_detectors_with_post_mask(decoder: str, force_streaming: Optional[boo
@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES)
def test_post_selection(decoder: str, force_streaming: Optional[bool]):
circuit = stim.Circuit("""
X_ERROR(0.6) 0
X_ERROR(0.4) 0
M 0
DETECTOR(2, 0, 0, 1) rec[-1]
OBSERVABLE_INCLUDE(0) rec[-1]
Expand All @@ -243,7 +248,7 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]):
M 1
DETECTOR(1, 0, 0) rec[-1]
OBSERVABLE_INCLUDE(0) rec[-1]

X_ERROR(0.1) 2
M 2
OBSERVABLE_INCLUDE(0) rec[-1]
Expand All @@ -259,9 +264,9 @@ def test_post_selection(decoder: str, force_streaming: Optional[bool]):
__private__unstable__force_decode_on_disk=force_streaming,
custom_decoders=TEST_CUSTOM_DECODERS,
)
assert 1050 <= result.discards <= 1350
assert 650 <= result.discards <= 950
if 'vacuous' not in decoder:
assert 40 <= result.errors <= 160
assert 60 <= result.errors <= 240


@pytest.mark.parametrize('decoder,force_streaming', DECODER_CASES)
Expand Down
Loading