77import sys
88import time
99from pathlib import Path
10- from typing import Final
1110
1211import numpy as np
1312import svs
1615from . import consts , utils
1716from .loader import create_loader
1817
19- STR_TO_STRATEGY : Final [dict [str , svs .LVQStrategy ]] = {
20- "auto" : svs .LVQStrategy .Auto ,
21- "sequential" : svs .LVQStrategy .Sequential ,
22- "turbo" : svs .LVQStrategy .Turbo ,
23- }
24-
25-
2618logger = logging .getLogger (__file__ )
2719
2820
@@ -38,7 +30,6 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
3830 help = "Query type" ,
3931 choices = consts .STR_TO_DATA_TYPE .keys (),
4032 default = "float32" ,
41- type = consts .STR_TO_DATA_TYPE .get ,
4233 )
4334 parser .add_argument ("--idx_dir" , help = "Index dir" , type = Path )
4435 parser .add_argument ("--data_dir" , help = "Data dir" , type = Path )
@@ -58,11 +49,10 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
5849 type = Path ,
5950 )
6051 parser .add_argument (
61- "--strategy " ,
52+ "--lvq_strategy " ,
6253 help = "LVQ strategy" ,
63- choices = tuple (STR_TO_STRATEGY .keys ()),
54+ choices = tuple (consts . STR_TO_LVQ_STRATEGY .keys ()),
6455 default = "auto" ,
65- type = STR_TO_STRATEGY .get ,
6656 )
6757 parser .add_argument (
6858 "--leanvec_dims" , help = "LeanVec dimensionality" , default = - 4 , type = int
@@ -115,7 +105,6 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
115105 "--distance" ,
116106 choices = tuple (consts .STR_TO_DISTANCE .keys ()),
117107 default = "mip" ,
118- type = consts .STR_TO_DISTANCE .get ,
119108 )
120109 parser .add_argument (
121110 "--load_from_static" ,
@@ -151,6 +140,7 @@ def search(
151140 calibration_query_path : Path | None = None ,
152141 calibration_ground_truth_path : Path | None = None ,
153142 load_from_static : bool = False ,
143+ lvq_strategy : svs .LVQStrategy | None = None ,
154144):
155145 logger .info ({"search_args" : locals ()})
156146 logger .info (utils .read_system_config ())
@@ -178,6 +168,7 @@ def search(
178168 compress = compress ,
179169 leanvec_dims = leanvec_dims ,
180170 leanvec_alignment = leanvec_alignment ,
171+ lvq_strategy = lvq_strategy ,
181172 )
182173
183174 if static :
@@ -337,11 +328,12 @@ def main(argv: str | None = None) -> None:
337328 prefetch_steps = args .prefetch_step ,
338329 num_rep = args .num_rep ,
339330 static = args .static ,
340- distance = args .distance ,
341- query_type = args .query_type ,
331+ distance = consts . STR_TO_DISTANCE [ args .distance ] ,
332+ query_type = consts . STR_TO_DATA_TYPE [ args .query_type ] ,
342333 calibration_query_path = args .calibration_query_file ,
343334 calibration_ground_truth_path = args .calibration_ground_truth_file ,
344335 load_from_static = args .load_from_static ,
336+ lvq_strategy = consts .STR_TO_LVQ_STRATEGY .get (args .lvq_strategy , None ),
345337 )
346338
347339
0 commit comments