44
55import argparse
66import logging
7- import os
87import sys
98import tempfile
109import time
1110from pathlib import Path
1211
1312import numpy as np
13+ import numpy .typing as npt
1414import svs
1515from tqdm import tqdm
1616
17- from . import consts
17+ from . import consts , utils
18+ from .generate_leanvec_matrices import (
19+ generate_leanvec_matrices ,
20+ save_leanvec_matrices ,
21+ )
1822from .loader import create_loader
19- from . import utils
2023
2124logger = logging .getLogger (__file__ )
2225
@@ -86,18 +89,83 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
8689 parser .add_argument (
8790 "--leanvec_dims" , help = "LeanVec dimensionality" , type = int
8891 )
89- parser .add_argument ("--no_save" , action = "store_true" )
92+ parser .add_argument (
93+ "--no_save" , action = "store_true" , help = "Do not save built index"
94+ )
95+ parser .add_argument (
96+ "--train_query_file" ,
97+ help = "Query *vecs file for LeanVec out-of-distribution training" ,
98+ type = Path ,
99+ )
100+ parser .add_argument (
101+ "--train_max_vectors" ,
102+ help = "Maximum number of base vectors from vecs_file"
103+ " to use for LeanVec out-of-distribution training (0 for all)" ,
104+ type = int ,
105+ default = consts .DEFAULT_LEANVEC_TRAIN_MAX_VECTORS ,
106+ )
107+ parser .add_argument (
108+ "--no_save_matrices" ,
109+ action = "store_true" ,
110+ help = "Do not save LeanVec matrices" ,
111+ )
112+ parser .add_argument (
113+ "--data_matrix_file" ,
114+ help = "Data matrix npy file for LeanVec" ,
115+ type = Path ,
116+ )
117+ parser .add_argument (
118+ "--query_matrix_file" ,
119+ help = "Query matrix npy file for LeanVec" ,
120+ type = Path ,
121+ )
90122 return parser .parse_args (argv )
91123
92124
93- def main (argv : str | None = None ) -> None :
125+ def main (argv : list [ str ] | None = None ) -> None :
94126 args = _read_args (argv )
95127 log_file = utils .configure_logger (
96128 logger , args .log_dir if args .log_dir is not None else args .out_dir
97129 )
98130 print ("Logging to" , log_file , sep = "\n " )
99131 logger .info ({"argv" : argv if argv else sys .argv })
100132 args .out_dir .mkdir (exist_ok = True )
133+ if args .data_matrix_file is not None :
134+ if args .query_matrix_file is None :
135+ raise ValueError (
136+ "query_matrix_file must be provided with data_matrix_file"
137+ )
138+ data_matrix = np .load (args .data_matrix_file )
139+ query_matrix = np .load (args .query_matrix_file )
140+ elif args .train_query_file is not None :
141+ (data_matrix , query_matrix ), (leanvec_dims_effective , _ ) = (
142+ generate_leanvec_matrices (
143+ args .vecs_file ,
144+ args .train_query_file ,
145+ args .train_max_vectors ,
146+ args .leanvec_dims ,
147+ )
148+ )
149+ if not args .no_save_matrices :
150+ data_matrix_path , query_matrix_path = save_leanvec_matrices (
151+ args .vecs_file ,
152+ args .train_query_file ,
153+ args .train_max_vectors ,
154+ leanvec_dims_effective ,
155+ data_matrix ,
156+ query_matrix ,
157+ args .out_dir ,
158+ )
159+ logger .info (
160+ {
161+ "saved_leanvec_matrices" : (
162+ data_matrix_path ,
163+ query_matrix_path ,
164+ )
165+ }
166+ )
167+ else :
168+ data_matrix = query_matrix = None
101169 if args .static :
102170 index , name = build_static (
103171 vecs_path = args .vecs_file ,
@@ -110,6 +178,8 @@ def main(argv: str | None = None) -> None:
110178 alpha = args .alpha ,
111179 max_threads = args .max_threads ,
112180 leanvec_dims = args .leanvec_dims ,
181+ data_matrix = data_matrix ,
182+ query_matrix = query_matrix ,
113183 )
114184 else :
115185 index , name , ingest_time , delete_time = build_dynamic (
@@ -135,6 +205,8 @@ def main(argv: str | None = None) -> None:
135205 convert_vecs = args .convert_vecs ,
136206 tmp_dir = args .tmp_dir ,
137207 leanvec_dims = args .leanvec_dims ,
208+ data_matrix = data_matrix ,
209+ query_matrix = query_matrix ,
138210 )
139211 np .save (args .out_dir / (name + ".ingest.npy" ), ingest_time )
140212 if args .num_vectors_delete > 0 :
@@ -167,6 +239,8 @@ def build_dynamic(
167239 convert_vecs : bool = False ,
168240 tmp_dir : Path = Path ("/dev/shm" ),
169241 leanvec_dims : int | None = None ,
242+ data_matrix : npt .NDArray | None = None ,
243+ query_matrix : npt .NDArray | None = None ,
170244) -> tuple [svs .DynamicVamana , str ]:
171245 """Build SVS index."""
172246 logger .info ({"build_args" : locals ()})
@@ -264,6 +338,8 @@ def build_dynamic(
264338 data_dir = tmp_idx_dir / "data" ,
265339 compress = not svs_type .startswith ("float" ),
266340 leanvec_dims = leanvec_dims ,
341+ data_matrix = data_matrix ,
342+ query_matrix = query_matrix ,
267343 )
268344 index = svs .DynamicVamana (
269345 str (tmp_idx_dir / "config" ),
@@ -343,6 +419,8 @@ def build_static(
343419 alpha : float | None = None ,
344420 max_threads : int = 1 ,
345421 leanvec_dims : int | None = None ,
422+ data_matrix : npt .NDArray | None = None ,
423+ query_matrix : npt .NDArray | None = None ,
346424) -> tuple [svs .Vamana , str ]:
347425 logger .info ({"build_args" : locals ()})
348426 logger .info (utils .read_system_config ())
@@ -360,7 +438,11 @@ def build_static(
360438 index = svs .Vamana .build (
361439 parameters ,
362440 create_loader (
363- svs_type , vecs_path = vecs_path , leanvec_dims = leanvec_dims
441+ svs_type ,
442+ vecs_path = vecs_path ,
443+ leanvec_dims = leanvec_dims ,
444+ data_matrix = data_matrix ,
445+ query_matrix = query_matrix ,
364446 ),
365447 distance ,
366448 num_threads = max_threads ,
0 commit comments