Skip to content

Commit dab6b68

Browse files
Merge pull request #2 from mgorny/index-json
[WheelVariants] Implement getting sorted variants aided by `variants.json`
2 parents d89351b + b44d59b commit dab6b68

File tree

5 files changed

+204
-5
lines changed

5 files changed

+204
-5
lines changed

.github/workflows/ci.yml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
name: CI
2+
on: [push, pull_request]
3+
jobs:
4+
test:
5+
strategy:
6+
matrix:
7+
python-version: ["3.10", "3.11", "3.12", "3.13", "pypy-3.10", "pypy-3.11"]
8+
fail-fast: false
9+
runs-on: ubuntu-latest
10+
steps:
11+
- name: Checkout
12+
uses: actions/checkout@v4
13+
- name: Set up Python
14+
uses: actions/setup-python@v5
15+
with:
16+
python-version: ${{ matrix.python-version }}
17+
- name: Install uv
18+
uses: astral-sh/setup-uv@v5
19+
with:
20+
python-version: ${{ matrix.python-version }}
21+
- name: Install the package and test dependencies
22+
run: |
23+
uv venv
24+
uv pip install -e '.[test]'
25+
- name: Run tests
26+
run: uv run --no-project pytest

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ dev = [
4040
]
4141
test = [
4242
"jsondiff>=2.2,<2.3",
43+
"hypothesis>=6.0.0,<7",
4344
"pytest>=8.0.0,<9.0.0",
4445
"pytest-cov>=5.0.0,<6.0.0",
4546
"pytest-dotenv>=0.5.0,<1.0.0",

tests/test_combinations.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
import json
22
from pathlib import Path
3+
import random
4+
import string
35

6+
from hypothesis import assume
7+
from hypothesis import example
8+
from hypothesis import given
9+
from hypothesis import strategies as st
410
import jsondiff
11+
import pytest
12+
from variantlib.combination import filtered_sorted_variants
513
from variantlib.combination import get_combinations
614
from variantlib.config import KeyConfig
715
from variantlib.config import ProviderConfig
16+
from variantlib.meta import VariantDescription
817

918

10-
def test_get_combinations():
11-
"""Test `get_combinations` yields the expected result in the right order."""
19+
@pytest.fixture(scope="session")
20+
def configs():
1221
config_custom_hw = ProviderConfig(
1322
provider="custom_hw",
1423
configs=[
@@ -24,8 +33,11 @@ def test_get_combinations():
2433
],
2534
)
2635

27-
configs = [config_custom_hw, config_networking]
36+
return [config_custom_hw, config_networking]
37+
2838

39+
def test_get_combinations(configs):
40+
"""Test `get_combinations` yields the expected result in the right order."""
2941
result = [vdesc.serialize() for vdesc in get_combinations(configs)]
3042

3143
json_file = Path("tests/artifacts/expected.json")
@@ -37,3 +49,81 @@ def test_get_combinations():
3749

3850
differences = jsondiff.diff(result, expected)
3951
assert not differences, f"Serialization altered JSON: {differences}"
52+
53+
54+
def desc_to_json(desc_list: list[VariantDescription]) -> dict:
55+
shuffled_desc_list = list(desc_list)
56+
random.shuffle(shuffled_desc_list)
57+
for desc in shuffled_desc_list:
58+
variant_dict = {}
59+
for variant_meta in desc:
60+
provider_dict = variant_dict.setdefault(variant_meta.provider, {})
61+
provider_dict[variant_meta.key] = variant_meta.value
62+
yield (desc.hexdigest, variant_dict)
63+
64+
65+
def test_filtered_sorted_variants_roundtrip(configs):
66+
"""Test that we can round-trip all combinations via variants.json and get the same result."""
67+
combinations = list(get_combinations(configs))
68+
variants_from_json = {k: v for k, v in desc_to_json(combinations)}
69+
assert filtered_sorted_variants(variants_from_json, configs) == combinations
70+
71+
72+
@example(
73+
[
74+
ProviderConfig(
75+
provider="A",
76+
configs=[
77+
KeyConfig(key="A1", values=["x"]),
78+
KeyConfig(key="A2", values=["x"]),
79+
],
80+
),
81+
ProviderConfig(provider="B", configs=[KeyConfig(key="B1", values=["x"])]),
82+
ProviderConfig(provider="C", configs=[KeyConfig(key="C1", values=["x"])]),
83+
]
84+
)
85+
@given(
86+
st.lists(
87+
min_size=1,
88+
max_size=3,
89+
unique_by=lambda provider_cfg: provider_cfg.provider,
90+
elements=st.builds(
91+
ProviderConfig,
92+
provider=st.text(
93+
string.ascii_letters + string.digits + "_", min_size=1, max_size=64
94+
),
95+
configs=st.lists(
96+
min_size=1,
97+
max_size=2,
98+
unique_by=lambda key_cfg: key_cfg.key,
99+
elements=st.builds(
100+
KeyConfig,
101+
key=st.text(
102+
alphabet=string.ascii_letters + string.digits + "_",
103+
min_size=1,
104+
max_size=64,
105+
),
106+
values=st.lists(
107+
min_size=1,
108+
max_size=3,
109+
unique=True,
110+
elements=st.text(
111+
alphabet=string.ascii_letters + string.digits + "_.",
112+
min_size=1,
113+
max_size=64,
114+
),
115+
),
116+
),
117+
),
118+
),
119+
)
120+
)
121+
def test_filtered_sorted_variants_roundtrip_fuzz(configs):
122+
def filter_long_combinations():
123+
for i, x in enumerate(get_combinations(configs)):
124+
assume(i < 65536)
125+
yield x
126+
127+
combinations = list(filter_long_combinations())
128+
variants_from_json = {k: v for k, v in desc_to_json(combinations)}
129+
assert filtered_sorted_variants(variants_from_json, configs) == combinations

variantlib/combination.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import itertools
2+
import logging
23
from collections.abc import Generator
34

45
from variantlib.config import ProviderConfig
56
from variantlib.meta import VariantDescription
67
from variantlib.meta import VariantMeta
78

9+
logger = logging.getLogger(__name__)
10+
811

912
def get_combinations(data: list[ProviderConfig]) -> Generator[VariantDescription]:
1013
"""Generate all possible combinations of `VariantMeta` given a list of
@@ -30,6 +33,79 @@ def get_combinations(data: list[ProviderConfig]) -> Generator[VariantDescription
3033
yield VariantDescription(data=vmetas)
3134

3235

36+
def unpack_variants_from_json(variants_from_json: dict
37+
) -> Generator[VariantDescription]:
38+
def variant_to_metas(providers: dict) -> VariantMeta:
39+
for provider, keys in providers.items():
40+
for key, value in keys.items():
41+
yield VariantMeta(provider=provider,
42+
key=key,
43+
value=value)
44+
45+
for variant_hash, providers in variants_from_json.items():
46+
desc = VariantDescription(variant_to_metas(providers))
47+
assert variant_hash == desc.hexdigest
48+
yield desc
49+
50+
51+
def filtered_sorted_variants(variants_from_json: dict,
52+
data: list[ProviderConfig]
53+
) -> Generator[VariantDescription]:
54+
providers = {}
55+
for provider_idx, provider_cnf in enumerate(data):
56+
keys = {}
57+
for key_idx, key_cnf in enumerate(provider_cnf.configs):
58+
keys[key_cnf.key] = key_idx, key_cnf.values
59+
providers[provider_cnf.provider] = provider_idx, keys
60+
61+
missing_providers = set()
62+
missing_keys = {}
63+
64+
def variant_filter(desc: VariantDescription):
65+
# Filter out the variant, unless all of its metas are supported.
66+
for meta in desc:
67+
if (provider_data := providers.get(meta.provider)) is None:
68+
missing_providers.add(meta.provider)
69+
return False
70+
_, keys = provider_data
71+
if (key_data := keys.get(meta.key)) is None:
72+
missing_keys.setdefault(meta.provider, set()).add(meta.key)
73+
return False
74+
_, values = key_data
75+
if meta.value not in values:
76+
return False
77+
return True
78+
79+
def meta_key(meta: VariantMeta) -> tuple[int, int, int]:
80+
# The sort key is a tuple of (provider, key, value) indices, so that
81+
# the metas with more preferred (provider, key, value) sort first.
82+
provider_idx, keys = providers.get(meta.provider)
83+
key_idx, values = keys.get(meta.key)
84+
value_idx = values.index(meta.value)
85+
return provider_idx, key_idx, value_idx
86+
87+
def variant_sort_key_gen(desc: VariantDescription) -> Generator[tuple]:
88+
# Variants with more matched values should go first.
89+
yield -len(desc.data)
90+
# Sort meta sort keys by their sort keys, so that metas containing
91+
# more preferred sort key sort first.
92+
meta_keys = sorted(meta_key(x) for x in desc.data)
93+
# Always prefer all values from the "stronger" keys over "weaker".
94+
yield from (x[0:2] for x in meta_keys)
95+
yield from (x[2] for x in meta_keys)
96+
97+
res = sorted(filter(variant_filter,
98+
unpack_variants_from_json(variants_from_json)),
99+
key=lambda x: tuple(variant_sort_key_gen(x)))
100+
if missing_providers:
101+
logger.warn("No plugins provide the following variant providers: "
102+
f"{' '.join(missing_providers)}; some variants will be ignored")
103+
for provider, provider_missing_keys in missing_keys.items():
104+
logger.warn(f"The {provider} provider does not provide the following expected keys: "
105+
f"{' '.join(provider_missing_keys)}; some variants will be ignored")
106+
return res
107+
108+
33109
if __name__ == "__main__": # pragma: no cover
34110
import json
35111
from pathlib import Path

variantlib/platform.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from importlib.metadata import entry_points
77
from typing import TYPE_CHECKING
88

9+
from variantlib.combination import filtered_sorted_variants
910
from variantlib.combination import get_combinations
1011
from variantlib.config import ProviderConfig
1112

@@ -82,6 +83,7 @@ def _query_variant_plugins() -> dict[str, ProviderConfig]:
8283

8384
def get_variant_hashes_by_priority(
8485
provider_priority_dict: dict[str:int] | None = None,
86+
variants_json: dict | None = None,
8587
) -> Generator[VariantDescription]:
8688
plugins = entry_points().select(group="variantlib.plugins")
8789

@@ -134,7 +136,11 @@ def get_variant_hashes_by_priority(
134136
sorted_provider_cfgs = [provider_cfgs[plugin.name] for plugin in plugins]
135137

136138
if sorted_provider_cfgs:
137-
for variant_desc in get_combinations(sorted_provider_cfgs):
138-
yield variant_desc.hexdigest
139+
if (variants_json or {}).get("variants") is not None:
140+
for variant_desc in filtered_sorted_variants(variants_json["variants"], sorted_provider_cfgs):
141+
yield variant_desc.hexdigest
142+
else:
143+
for variant_desc in get_combinations(sorted_provider_cfgs):
144+
yield variant_desc.hexdigest
139145
else:
140146
yield from []

0 commit comments

Comments
 (0)