Skip to content

Commit 13be65e

Browse files
committed
Add option to repeat queries in ground truth
1 parent 5386b39 commit 13be65e

File tree

3 files changed

+216
-19
lines changed

3 files changed

+216
-19
lines changed

src/svsbench/generate_ground_truth.py

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
2323
parser.add_argument("--vecs_file", help="Vectors *vecs file", type=Path)
2424
parser.add_argument("--query_file", help="Query vectors file", type=Path)
2525
parser.add_argument("--out_file", help="Output file", type=Path)
26+
parser.add_argument(
27+
"--query_out_file",
28+
help="Output file for query vectors if --num_query_vectors is given",
29+
type=Path,
30+
)
2631
parser.add_argument(
2732
"--distance",
2833
help="Distance",
@@ -33,6 +38,14 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
3338
"-k", help="Number of neighbors", type=int, default=100
3439
)
3540
parser.add_argument("--num_vectors", help="Number of vectors", type=int)
41+
parser.add_argument(
42+
"--num_query_vectors",
43+
help="Number of query vectors."
44+
" If given, query vectors will be shuffled."
45+
" If more than in the query file, the query vectors will be shuffled"
46+
" and repeated as needed.",
47+
type=int,
48+
)
3649
parser.add_argument(
3750
"--shuffle", help="Shuffle order of vectors", action="store_true"
3851
)
@@ -54,8 +67,10 @@ def main(argv: list[str] | None = None) -> None:
5467
k=args.k,
5568
num_threads=args.max_threads,
5669
out_file=args.out_file,
70+
query_out_path=args.query_out_file,
5771
shuffle=args.shuffle,
5872
seed=args.seed,
73+
num_query_vectors=args.num_query_vectors,
5974
)
6075

6176

@@ -64,31 +79,72 @@ def generate_ground_truth(
6479
vecs_path: Path,
6580
query_file: Path,
6681
distance: svs.DistanceType,
67-
num_vectors: int | None,
82+
num_vectors: int | None = None,
6883
k: int = 100,
6984
num_threads: int = 1,
7085
out_file: Path | None = None,
86+
query_out_path: Path | None = None,
7187
shuffle: bool = False,
7288
seed: int = 42,
89+
num_query_vectors: int | None = None,
7390
) -> None:
74-
if out_file is None:
75-
out_file = utils.ground_truth_path(
76-
vecs_path, query_file, distance, num_vectors, seed if shuffle else None,
91+
if out_file is not None and out_file.suffix != ".ivecs":
92+
raise SystemExit("Error: --out_file must end in .ivecs")
93+
if (
94+
query_out_path is not None
95+
and query_out_path.suffix != query_file.suffix
96+
):
97+
raise SystemExit(
98+
"Error: --query_out_file must have the same suffix as --query_file"
7799
)
78-
else:
79-
if out_file.suffix != ".ivecs":
80-
raise SystemExit("Error: --out_file must end in .ivecs")
81-
out_file = str(out_file)
82100
queries = svs.read_vecs(str(query_file))
83101
vectors = svs.read_vecs(str(vecs_path))
84-
if num_vectors is None:
85-
num_vectors = vectors.shape[0]
102+
# If num_vectors is None or larger than the number of vectors,
103+
# slicing will return the whole array.
86104
vectors = vectors[:num_vectors]
87105
if shuffle:
88-
vectors = vectors[np.random.default_rng(seed).permutation(num_vectors)]
106+
np.random.default_rng(seed).shuffle(vectors)
89107
index = svs.Flat(vectors, distance=distance, num_threads=num_threads)
90108
idxs, _ = index.search(queries, k)
91-
svs.write_vecs(idxs.astype(np.uint32), out_file)
109+
if num_query_vectors is not None:
110+
queries_all = np.empty_like(
111+
queries, shape=(num_query_vectors, queries.shape[1])
112+
)
113+
ground_truth_all = np.empty_like(
114+
idxs, shape=(num_query_vectors, idxs.shape[1])
115+
)
116+
rng = np.random.default_rng(seed)
117+
cursor = 0
118+
while cursor < num_query_vectors:
119+
permutation = rng.permutation(len(queries))
120+
batch_size = min(num_query_vectors - cursor, len(queries))
121+
queries_all[cursor : cursor + batch_size] = queries[
122+
permutation[:batch_size]
123+
]
124+
ground_truth_all[cursor : cursor + batch_size] = idxs[
125+
permutation[:batch_size]
126+
]
127+
cursor += batch_size
128+
if query_out_path is None:
129+
query_out_path = (
130+
query_file.parent
131+
/ f"{query_file.stem}-{num_query_vectors}_{seed}"
132+
f"{query_file.suffix}"
133+
)
134+
svs.write_vecs(queries_all, str(query_out_path))
135+
queries_path = query_out_path
136+
else:
137+
queries_path = query_file
138+
ground_truth_all = idxs
139+
if out_file is None:
140+
out_file = utils.ground_truth_path(
141+
vecs_path,
142+
queries_path,
143+
distance,
144+
num_vectors,
145+
seed if shuffle else None,
146+
)
147+
svs.write_vecs(ground_truth_all.astype(np.uint32), str(out_file))
92148
logger.info({"ground_truth_saved": out_file})
93149

94150

tests/conftest.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,23 @@
1515
8: svs.LeanVecKind.lvq8,
1616
}
1717

18+
RANDOM_VECTORS_SHAPE: Final = (1000, 100)
19+
NUM_RANDOM_QUERY_VECTORS: Final = 100
20+
GROUND_TRUTH_K: Final = 100
21+
1822

1923
def random_array(dtype: np.dtype) -> np.ndarray:
2024
rng = np.random.default_rng(42)
2125
if np.dtype(dtype).kind == "i":
2226
iinfo = np.iinfo(dtype)
23-
return rng.integers(iinfo.min, iinfo.max, (1000, 100), dtype=dtype)
27+
return rng.integers(
28+
iinfo.min, iinfo.max, RANDOM_VECTORS_SHAPE, dtype=dtype
29+
)
2430
else:
25-
return rng.random((1000, 100)).astype(dtype)
31+
return rng.random(RANDOM_VECTORS_SHAPE).astype(dtype)
2632

2733

28-
@pytest.fixture(
29-
scope="session", params=consts.SUFFIX_TO_SVS_TYPE.keys()
30-
)
34+
@pytest.fixture(scope="session", params=consts.SUFFIX_TO_SVS_TYPE.keys())
3135
def tmp_vecs(request, tmp_path_factory):
3236
suffix = request.param
3337
vecs_path = tmp_path_factory.mktemp("vecs") / ("random" + suffix)
@@ -134,7 +138,10 @@ def index_dir_with_svs_type_and_dynamic(request, tmp_path_factory):
134138
def query_path(tmp_path_factory) -> Path:
135139
path = tmp_path_factory.mktemp("query") / "query.fvecs"
136140
svs.write_vecs(
137-
np.random.default_rng(42).random((100, 100)).astype(np.float32), path
141+
np.random.default_rng(42)
142+
.random((NUM_RANDOM_QUERY_VECTORS, RANDOM_VECTORS_SHAPE[1]))
143+
.astype(np.float32),
144+
path,
138145
)
139146
return path
140147

@@ -162,7 +169,7 @@ def ground_truth_path(
162169
)
163170
vectors = np.load(index_dir / "data.npy")
164171
index = svs.Flat(vectors, distance=distance, num_threads=num_threads)
165-
idxs, _ = index.search(svs.read_vecs(str(query_path)), 100)
172+
idxs, _ = index.search(svs.read_vecs(str(query_path)), GROUND_TRUTH_K)
166173
ground_truth_path = (
167174
tmp_path_factory.mktemp("ground_truth")
168175
/ f"ground_truth_{index_svs_type}.ivecs"
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (C) 2025 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
import functools
4+
5+
import conftest
6+
import pytest
7+
import svs
8+
9+
from svsbench.generate_ground_truth import generate_ground_truth, main
10+
11+
12+
def test_generate_ground_truth_no_shuffle(
13+
tmp_vecs, query_path, distance, num_threads, tmp_path_factory
14+
):
15+
if tmp_vecs.suffix == ".hvecs":
16+
pytest.xfail("Not implemented")
17+
out_file = tmp_path_factory.mktemp("output") / "ground_truth.ivecs"
18+
k = 10
19+
generate_ground_truth(
20+
vecs_path=tmp_vecs,
21+
query_file=query_path,
22+
distance=distance,
23+
num_vectors=None,
24+
k=k,
25+
num_threads=num_threads,
26+
out_file=out_file,
27+
query_out_path=None,
28+
shuffle=False,
29+
seed=42,
30+
)
31+
assert out_file.is_file()
32+
gt = svs.read_vecs(str(out_file))
33+
assert gt.shape == (conftest.NUM_RANDOM_QUERY_VECTORS, k), (
34+
"Expected (num_queries, k) shape"
35+
)
36+
37+
38+
def test_generate_ground_truth_shuffle(
39+
tmp_vecs, query_path, distance, num_threads, tmp_path_factory
40+
):
41+
if tmp_vecs.suffix == ".hvecs":
42+
pytest.xfail("Not implemented")
43+
out_file = (
44+
tmp_path_factory.mktemp("output") / "ground_truth_shuffled.ivecs"
45+
)
46+
k = 5
47+
generate_ground_truth(
48+
vecs_path=tmp_vecs,
49+
query_file=query_path,
50+
distance=distance,
51+
num_vectors=500,
52+
k=k,
53+
num_threads=num_threads,
54+
out_file=out_file,
55+
shuffle=True,
56+
seed=2,
57+
)
58+
assert out_file.is_file()
59+
gt = svs.read_vecs(str(out_file))
60+
assert gt.shape == (conftest.NUM_RANDOM_QUERY_VECTORS, k)
61+
62+
63+
def test_generate_ground_truth_num_query_vectors(
64+
tmp_vecs, query_path, distance, num_threads, tmp_path_factory
65+
):
66+
k = 7
67+
if tmp_vecs.suffix == ".hvecs":
68+
pytest.xfail("Not implemented")
69+
out_dir = tmp_path_factory.mktemp("output")
70+
out_file = out_dir / "ground_truth_subqueries.ivecs"
71+
query_out_path = out_dir / "queries_out.fvecs"
72+
generate_ground_truth_partial = functools.partial(
73+
generate_ground_truth,
74+
vecs_path=tmp_vecs,
75+
query_file=query_path,
76+
distance=distance,
77+
k=k,
78+
num_threads=num_threads,
79+
out_file=out_file,
80+
query_out_path=query_out_path,
81+
shuffle=True,
82+
)
83+
for num_query_vectors in [20, 200]:
84+
generate_ground_truth_partial(num_query_vectors=num_query_vectors)
85+
gt = svs.read_vecs(str(out_file))
86+
new_queries = svs.read_vecs(str(query_out_path))
87+
assert gt.shape == (num_query_vectors, k)
88+
assert new_queries.shape == (
89+
num_query_vectors,
90+
conftest.RANDOM_VECTORS_SHAPE[1],
91+
)
92+
93+
94+
def test_generate_ground_truth_main(
95+
tmp_vecs, query_path, num_threads, tmp_path_factory
96+
):
97+
if tmp_vecs.suffix == ".hvecs":
98+
pytest.xfail("Not implemented")
99+
out_dir = tmp_path_factory.mktemp("cli")
100+
out_file = out_dir / "gt.ivecs"
101+
query_out_path = out_dir / "queries_out.fvecs"
102+
k = 8
103+
num_query_vectors = 150
104+
105+
argv = [
106+
"--vecs_file",
107+
str(tmp_vecs),
108+
"--query_file",
109+
str(query_path),
110+
"--out_file",
111+
str(out_file),
112+
"--distance",
113+
"mip",
114+
"-k",
115+
str(k),
116+
"--max_threads",
117+
str(num_threads),
118+
"--seed",
119+
"42",
120+
"--num_query_vectors",
121+
str(num_query_vectors),
122+
"--query_out_file",
123+
str(query_out_path),
124+
]
125+
126+
main(argv)
127+
128+
gt = svs.read_vecs(str(out_file))
129+
new_queries = svs.read_vecs(str(query_out_path))
130+
assert gt.shape == (num_query_vectors, k)
131+
assert new_queries.shape == (
132+
num_query_vectors,
133+
conftest.RANDOM_VECTORS_SHAPE[1],
134+
)

0 commit comments

Comments
 (0)