|
6 | 6 | from collections import defaultdict
|
7 | 7 | from dataclasses import dataclass, field
|
8 | 8 | from functools import cached_property, reduce
|
9 |
| -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Mapping, Optional |
| 9 | +from typing import (TYPE_CHECKING, Any, Callable, DefaultDict, Dict, List, |
| 10 | + Mapping, Optional) |
10 | 11 | from typing import Sequence as GenericSequence
|
11 | 12 | from typing import Set, Tuple, Union, cast
|
12 | 13 |
|
@@ -256,7 +257,8 @@ def output_token_ids(self) -> Tuple[int, ...]:
|
256 | 257 | return tuple(self._output_token_ids)
|
257 | 258 |
|
258 | 259 | @output_token_ids.setter
|
259 |
| - def output_token_ids(self, new_output_token_ids: List[int]) -> None: |
| 260 | + def output_token_ids(self, |
| 261 | + new_output_token_ids: GenericSequence[int]) -> None: |
260 | 262 | self._output_token_ids = array(VLLM_TOKEN_ID_ARRAY_TYPE,
|
261 | 263 | new_output_token_ids)
|
262 | 264 | self._update_cached_all_tokens()
|
@@ -1173,7 +1175,7 @@ def get_all_seq_ids_and_request_ids(
|
1173 | 1175 | sequence ids.
|
1174 | 1176 | """
|
1175 | 1177 | seq_ids: List[int] = []
|
1176 |
| - request_id_seq_ids_mapping: Dict[str, Set[int]] = defaultdict(set) |
| 1178 | + request_id_seq_ids_mapping: DefaultDict[str, Set[int]] = defaultdict(set) |
1177 | 1179 | for sg in seq_group_metadata_list:
|
1178 | 1180 | for seq_id in sg.seq_data:
|
1179 | 1181 | seq_ids.append(seq_id)
|
|
0 commit comments