Skip to content

Commit 9236d1a

Browse files
authored
Add calibration options (#15)
1 parent e23d182 commit 9236d1a

File tree

1 file changed

+36
-3
lines changed

1 file changed

+36
-3
lines changed

src/svsbench/search.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import sys
88
import time
99
from pathlib import Path
10+
from typing import Final
1011

1112
import numpy as np
1213
import svs
@@ -17,6 +18,15 @@
1718

1819
logger = logging.getLogger(__file__)
1920

21+
STR_TO_CALIBRATE_SEARCH_BUFFER: Final[
22+
dict[str, svs.VamanaSearchBufferOptimization]
23+
] = {
24+
"disable": svs.VamanaSearchBufferOptimization.Disable,
25+
"all": svs.VamanaSearchBufferOptimization.All,
26+
"roionly": svs.VamanaSearchBufferOptimization.ROIOnly,
27+
"roituneup": svs.VamanaSearchBufferOptimization.ROITuneUp,
28+
}
29+
2030

2131
def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
2232
"""Read command line arguments."""
@@ -111,6 +121,12 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace:
111121
action="store_true",
112122
help="Load from static index",
113123
)
124+
parser.add_argument("--no_calibrate_prefetchers", action="store_true")
125+
parser.add_argument(
126+
"--calibrate_search_buffer",
127+
choices=STR_TO_CALIBRATE_SEARCH_BUFFER.keys(),
128+
default="all",
129+
)
114130
return parser.parse_args(argv)
115131

116132

@@ -141,6 +157,8 @@ def search(
141157
calibration_ground_truth_path: Path | None = None,
142158
load_from_static: bool = False,
143159
lvq_strategy: svs.LVQStrategy | None = None,
160+
train_prefetchers: bool = True,
161+
search_buffer_optimization: svs.VamanaSearchBufferOptimization = svs.VamanaSearchBufferOptimization.All,
144162
):
145163
logger.info({"search_args": locals()})
146164
logger.info(utils.read_system_config())
@@ -207,8 +225,17 @@ def search(
207225
else:
208226
calibration_query = query
209227
calibration_ground_truth = ground_truth
228+
calibration_parameters = svs.VamanaCalibrationParameters()
229+
calibration_parameters.search_buffer_optimization = (
230+
search_buffer_optimization
231+
)
232+
calibration_parameters.train_prefetchers = train_prefetchers
210233
index.experimental_calibrate(
211-
calibration_query, calibration_ground_truth, count, recall
234+
calibration_query,
235+
calibration_ground_truth,
236+
count,
237+
recall,
238+
calibration_parameters,
212239
)
213240
logger.info(
214241
{
@@ -285,10 +312,12 @@ def search(
285312
"search_results": {
286313
"qps": qps,
287314
"qps_mean": np.mean(qps),
288-
"qps_rsd": np.std(qps, ddof=min(1, num_rep - 1)) / np.mean(qps),
315+
"qps_rsd": np.std(qps, ddof=min(1, num_rep - 1))
316+
/ np.mean(qps),
289317
"p95": p95s,
290318
"p95_mean": np.mean(p95s),
291-
"p95_rsd": np.std(p95s, ddof=min(1, num_rep - 1)) / np.mean(p95s),
319+
"p95_rsd": np.std(p95s, ddof=min(1, num_rep - 1))
320+
/ np.mean(p95s),
292321
"search_parameters": {
293322
"search_window_size": index.search_parameters.buffer_config.search_window_size,
294323
"search_buffer_capacity": index.search_parameters.buffer_config.search_buffer_capacity,
@@ -338,6 +367,10 @@ def main(argv: str | None = None) -> None:
338367
calibration_ground_truth_path=args.calibration_ground_truth_file,
339368
load_from_static=args.load_from_static,
340369
lvq_strategy=consts.STR_TO_LVQ_STRATEGY.get(args.lvq_strategy, None),
370+
train_prefetchers=not args.no_calibrate_prefetchers,
371+
search_buffer_optimization=STR_TO_CALIBRATE_SEARCH_BUFFER[
372+
args.calibrate_search_buffer
373+
],
341374
)
342375

343376

0 commit comments

Comments
 (0)