@@ -23,6 +23,11 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
2323 parser .add_argument ("--vecs_file" , help = "Vectors *vecs file" , type = Path )
2424 parser .add_argument ("--query_file" , help = "Query vectors file" , type = Path )
2525 parser .add_argument ("--out_file" , help = "Output file" , type = Path )
26+ parser .add_argument (
27+ "--query_out_file" ,
28+ help = "Output file for query vectors generated when num_queries given" ,
29+ type = Path ,
30+ )
2631 parser .add_argument (
2732 "--distance" ,
2833 help = "Distance" ,
@@ -33,6 +38,14 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
3338 "-k" , help = "Number of neighbors" , type = int , default = 100
3439 )
3540 parser .add_argument ("--num_vectors" , help = "Number of vectors" , type = int )
41+ parser .add_argument (
42+ "--num_query_vectors" ,
43+ help = "Number of query vectors."
44+ " If given, query vectors will be shuffled."
45+ " If more than in the query file, the query vectors will be shuffled"
46+ " and repeated as needed." ,
47+ type = int ,
48+ )
3649 parser .add_argument (
3750 "--shuffle" , help = "Shuffle order of vectors" , action = "store_true"
3851 )
@@ -54,8 +67,10 @@ def main(argv: str | None = None) -> None:
5467 k = args .k ,
5568 num_threads = args .max_threads ,
5669 out_file = args .out_file ,
70+ query_out_path = args .query_out_file ,
5771 shuffle = args .shuffle ,
5872 seed = args .seed ,
73+ num_query_vectors = args .num_query_vectors ,
5974 )
6075
6176
def generate_ground_truth(
    vecs_path: Path,
    query_file: Path,
    distance: svs.DistanceType,
    num_vectors: int | None = None,
    k: int = 100,
    num_threads: int = 1,
    out_file: Path | None = None,
    query_out_path: Path | None = None,
    shuffle: bool = False,
    seed: int = 42,
    num_query_vectors: int | None = None,
) -> None:
    """Compute exact k-nearest-neighbor ground truth and write it to disk.

    The vectors in ``vecs_path`` are optionally truncated to ``num_vectors``
    and shuffled, indexed with ``svs.Flat``, and searched with the queries
    from ``query_file``. The first array returned by the search is written
    as ``uint32`` to ``out_file``.

    If ``num_query_vectors`` is given, a resampled query set of exactly that
    many rows is built by concatenating fresh random permutations of the
    original queries (so queries repeat once the original set is exhausted).
    The search results are reordered identically, and the resampled queries
    are written to ``query_out_path``.

    Args:
        vecs_path: File holding the vectors to search.
        query_file: File holding the query vectors.
        distance: Distance type passed to ``svs.Flat``.
        num_vectors: Keep only the first ``num_vectors`` vectors; ``None``
            keeps all of them.
        k: Number of neighbors requested per query.
        num_threads: Thread count passed to ``svs.Flat``.
        out_file: Destination of the ground truth; must end in ``.ivecs``.
            When ``None``, the path comes from ``utils.ground_truth_path``.
        query_out_path: Destination of the resampled queries; must have the
            same suffix as ``query_file``. When ``None`` and
            ``num_query_vectors`` is given, the file is written next to
            ``query_file`` as ``<stem>-<num_query_vectors>_<seed><suffix>``.
        shuffle: Shuffle the vectors (seeded by ``seed``) before indexing.
        seed: Seed for both the vector shuffle and the query resampling.
        num_query_vectors: Number of query vectors to generate; ``None``
            leaves the queries untouched.

    Raises:
        SystemExit: If a path has the wrong suffix, ``num_query_vectors`` is
            negative, or queries must be generated from an empty query file.
    """
    if out_file is not None and out_file.suffix != ".ivecs":
        raise SystemExit("Error: --out_file must end in .ivecs")
    if (
        query_out_path is not None
        and query_out_path.suffix != query_file.suffix
    ):
        # The command-line flag is --query_out_file; name it as the user
        # typed it rather than by the internal parameter name.
        raise SystemExit(
            "Error: --query_out_file must have the same suffix as --query_file"
        )
    if num_query_vectors is not None and num_query_vectors < 0:
        # Fail before the expensive read/search instead of hitting NumPy's
        # "negative dimensions" error afterwards.
        raise SystemExit("Error: --num_query_vectors must be non-negative")

    queries = svs.read_vecs(str(query_file))
    vectors = svs.read_vecs(str(vecs_path))
    # If num_vectors is None or larger than the number of vectors,
    # slicing will return the whole array.
    vectors = vectors[:num_vectors]
    if shuffle:
        np.random.default_rng(seed).shuffle(vectors)
    index = svs.Flat(vectors, distance=distance, num_threads=num_threads)
    idxs, _ = index.search(queries, k)

    if num_query_vectors is not None:
        num_queries = len(queries)
        if num_query_vectors > 0 and num_queries == 0:
            # Every batch would be empty and the loop below would never
            # advance, so refuse instead of spinning forever.
            raise SystemExit(
                "Error: --query_file contains no vectors to sample from"
            )
        queries_all = np.empty_like(
            queries, shape=(num_query_vectors, queries.shape[1])
        )
        ground_truth_all = np.empty_like(
            idxs, shape=(num_query_vectors, idxs.shape[1])
        )
        rng = np.random.default_rng(seed)
        cursor = 0
        while cursor < num_query_vectors:
            # Draw a fresh permutation for every pass over the query set and
            # apply it to queries and results alike so rows stay aligned.
            permutation = rng.permutation(num_queries)
            batch_size = min(num_query_vectors - cursor, num_queries)
            selected = permutation[:batch_size]
            queries_all[cursor : cursor + batch_size] = queries[selected]
            ground_truth_all[cursor : cursor + batch_size] = idxs[selected]
            cursor += batch_size
        if query_out_path is None:
            query_out_path = (
                query_file.parent
                / f"{query_file.stem}-{num_query_vectors}_{seed}"
                f"{query_file.suffix}"
            )
        svs.write_vecs(queries_all, str(query_out_path))
        queries_path = query_out_path
    else:
        queries_path = query_file
        ground_truth_all = idxs

    if out_file is None:
        # Derive the default name from the query file actually used, so
        # ground truth for resampled queries does not collide with the
        # ground truth of the original query file.
        out_file = utils.ground_truth_path(
            vecs_path,
            queries_path,
            distance,
            num_vectors,
            seed if shuffle else None,
        )
    svs.write_vecs(ground_truth_all.astype(np.uint32), str(out_file))
    logger.info({"ground_truth_saved": out_file})
93149
94150
0 commit comments