|
18 | 18 | CALC_DIST_L2, |
19 | 19 | CALC_DIST_TANIMOTO, |
20 | 20 | DEFAULT_SEARCH_EXTENSION_RATE, |
| 21 | + EF, |
21 | 22 | FIELDS, |
22 | 23 | INT64_MAX, |
23 | 24 | ITERATION_EXTENSION_REDUCE_RATE, |
|
39 | 40 | SearchIterator = TypeVar("SearchIterator") |
40 | 41 |
|
41 | 42 |
|
42 | | -def extend_batch_size(batch_size: int) -> int: |
| 43 | +def extend_batch_size(batch_size: int, next_param: dict) -> int: |
| 44 | + if EF in next_param[PARAMS]: |
| 45 | + return min( |
| 46 | + MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE, next_param[PARAMS][EF] |
| 47 | + ) |
43 | 48 | return min(MAX_BATCH_SIZE, batch_size * DEFAULT_SEARCH_EXTENSION_RATE) |
44 | 49 |
|
45 | 50 |
|
@@ -294,6 +299,7 @@ def __init__( |
294 | 299 | } |
295 | 300 | self._expr = expr |
296 | 301 | self.__check_set_params(param) |
| 302 | + self.__check_for_special_index_param() |
297 | 303 | self._kwargs = kwargs |
298 | 304 | self._filtered_ids = [] |
299 | 305 | self._filtered_distance = None |
@@ -337,6 +343,15 @@ def __check_set_params(self, param: Dict): |
337 | 343 | if PARAMS not in self._param: |
338 | 344 | self._param[PARAMS] = {} |
339 | 345 |
|
| 346 | + def __check_for_special_index_param(self): |
| 347 | + if ( |
| 348 | + EF in self._param[PARAMS] |
| 349 | + and self._param[PARAMS][EF] < self._iterator_params[BATCH_SIZE] |
| 350 | + ): |
| 351 | + raise MilvusException( |
| 352 | + message="When using hnsw index, provided ef must be larger than or equal to batch size" |
| 353 | + ) |
| 354 | + |
340 | 355 | def __setup__pk_prop(self): |
341 | 356 | fields = self._schema[FIELDS] |
342 | 357 | for field in fields: |
@@ -472,7 +487,7 @@ def __execute_next_search(self, next_params: dict, next_expr: str) -> SearchPage |
472 | 487 | self._iterator_params["data"], |
473 | 488 | self._iterator_params["ann_field"], |
474 | 489 | next_params, |
475 | | - extend_batch_size(self._iterator_params[BATCH_SIZE]), |
| 490 | + extend_batch_size(self._iterator_params[BATCH_SIZE], next_params), |
476 | 491 | next_expr, |
477 | 492 | self._iterator_params["partition_names"], |
478 | 493 | self._iterator_params["output_fields"], |
|
0 commit comments