|
7 | 7 | import sys |
8 | 8 | import time |
9 | 9 | from pathlib import Path |
| 10 | +from typing import Final |
10 | 11 |
|
11 | 12 | import numpy as np |
12 | 13 | import svs |
|
17 | 18 |
|
18 | 19 | logger = logging.getLogger(__file__) |
19 | 20 |
|
| 21 | +STR_TO_CALIBRATE_SEARCH_BUFFER: Final[ |
| 22 | + dict[str, svs.VamanaSearchBufferOptimization] |
| 23 | +] = { |
| 24 | + "disable": svs.VamanaSearchBufferOptimization.Disable, |
| 25 | + "all": svs.VamanaSearchBufferOptimization.All, |
| 26 | + "roionly": svs.VamanaSearchBufferOptimization.ROIOnly, |
| 27 | + "roituneup": svs.VamanaSearchBufferOptimization.ROITuneUp, |
| 28 | +} |
| 29 | + |
20 | 30 |
|
21 | 31 | def _read_args(argv: list[str] | None = None) -> argparse.Namespace: |
22 | 32 | """Read command line arguments.""" |
@@ -111,6 +121,12 @@ def _read_args(argv: list[str] | None = None) -> argparse.Namespace: |
111 | 121 | action="store_true", |
112 | 122 | help="Load from static index", |
113 | 123 | ) |
| 124 | + parser.add_argument("--no_calibrate_prefetchers", action="store_true") |
| 125 | + parser.add_argument( |
| 126 | + "--calibrate_search_buffer", |
| 127 | + choices=STR_TO_CALIBRATE_SEARCH_BUFFER.keys(), |
| 128 | + default="all", |
| 129 | + ) |
114 | 130 | return parser.parse_args(argv) |
115 | 131 |
|
116 | 132 |
|
@@ -141,6 +157,8 @@ def search( |
141 | 157 | calibration_ground_truth_path: Path | None = None, |
142 | 158 | load_from_static: bool = False, |
143 | 159 | lvq_strategy: svs.LVQStrategy | None = None, |
| 160 | + train_prefetchers: bool = True, |
| 161 | + search_buffer_optimization: svs.VamanaSearchBufferOptimization = svs.VamanaSearchBufferOptimization.All, |
144 | 162 | ): |
145 | 163 | logger.info({"search_args": locals()}) |
146 | 164 | logger.info(utils.read_system_config()) |
@@ -207,8 +225,17 @@ def search( |
207 | 225 | else: |
208 | 226 | calibration_query = query |
209 | 227 | calibration_ground_truth = ground_truth |
| 228 | + calibration_parameters = svs.VamanaCalibrationParameters() |
| 229 | + calibration_parameters.search_buffer_optimization = ( |
| 230 | + search_buffer_optimization |
| 231 | + ) |
| 232 | + calibration_parameters.train_prefetchers = train_prefetchers |
210 | 233 | index.experimental_calibrate( |
211 | | - calibration_query, calibration_ground_truth, count, recall |
| 234 | + calibration_query, |
| 235 | + calibration_ground_truth, |
| 236 | + count, |
| 237 | + recall, |
| 238 | + calibration_parameters, |
212 | 239 | ) |
213 | 240 | logger.info( |
214 | 241 | { |
@@ -285,10 +312,12 @@ def search( |
285 | 312 | "search_results": { |
286 | 313 | "qps": qps, |
287 | 314 | "qps_mean": np.mean(qps), |
288 | | - "qps_rsd": np.std(qps, ddof=min(1, num_rep - 1)) / np.mean(qps), |
| 315 | + "qps_rsd": np.std(qps, ddof=min(1, num_rep - 1)) |
| 316 | + / np.mean(qps), |
289 | 317 | "p95": p95s, |
290 | 318 | "p95_mean": np.mean(p95s), |
291 | | - "p95_rsd": np.std(p95s, ddof=min(1, num_rep - 1)) / np.mean(p95s), |
| 319 | + "p95_rsd": np.std(p95s, ddof=min(1, num_rep - 1)) |
| 320 | + / np.mean(p95s), |
292 | 321 | "search_parameters": { |
293 | 322 | "search_window_size": index.search_parameters.buffer_config.search_window_size, |
294 | 323 | "search_buffer_capacity": index.search_parameters.buffer_config.search_buffer_capacity, |
@@ -338,6 +367,10 @@ def main(argv: str | None = None) -> None: |
338 | 367 | calibration_ground_truth_path=args.calibration_ground_truth_file, |
339 | 368 | load_from_static=args.load_from_static, |
340 | 369 | lvq_strategy=consts.STR_TO_LVQ_STRATEGY.get(args.lvq_strategy, None), |
| 370 | + train_prefetchers=not args.no_calibrate_prefetchers, |
| 371 | + search_buffer_optimization=STR_TO_CALIBRATE_SEARCH_BUFFER[ |
| 372 | + args.calibrate_search_buffer |
| 373 | + ], |
341 | 374 | ) |
342 | 375 |
|
343 | 376 |
|
|
0 commit comments