|
14 | 14 | import uuid |
15 | 15 | import typing |
16 | 16 | import typing_extensions |
17 | | -from collections import defaultdict, deque |
| 17 | +from collections import defaultdict, deque, Counter |
18 | 18 | from collections.abc import Iterator, Sequence, MutableSequence |
19 | 19 | from collections.abc import Mapping, MutableMapping, Set, MutableSet |
20 | 20 | from dataclasses import is_dataclass |
@@ -218,6 +218,13 @@ def typename(typ: Any, with_typing_module: bool = False) -> str: |
218 | 218 | return f"deque[{et}]" |
219 | 219 | else: |
220 | 220 | return "deque" |
| 221 | + elif is_counter(typ): |
| 222 | + args = type_args(typ) |
| 223 | + if args: |
| 224 | + et = thisfunc(args[0]) |
| 225 | + return f"Counter[{et}]" |
| 226 | + else: |
| 227 | + return "Counter" |
221 | 228 | elif is_tuple(typ): |
222 | 229 | args = type_args(typ) |
223 | 230 | if args: |
@@ -373,6 +380,11 @@ def recursive(cls: Union[type[Any], Any]) -> None: |
373 | 380 | args = type_args(cls) |
374 | 381 | if args: |
375 | 382 | recursive(args[0]) |
| 383 | + elif is_counter(cls): |
| 384 | + lst.add(Counter) |
| 385 | + args = type_args(cls) |
| 386 | + if args: |
| 387 | + recursive(args[0]) |
376 | 388 | elif is_tuple(cls): |
377 | 389 | lst.add(tuple) |
378 | 390 | for arg in type_args(cls): |
@@ -419,7 +431,7 @@ def recursive(cls: TypeLike) -> None: |
419 | 431 | args = type_args(cls) |
420 | 432 | if args: |
421 | 433 | recursive(args[0]) |
422 | | - elif is_list(cls) or is_set(cls) or is_deque(cls): |
| 434 | + elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls): |
423 | 435 | args = type_args(cls) |
424 | 436 | if args: |
425 | 437 | recursive(args[0]) |
@@ -462,7 +474,7 @@ def recursive(cls: Union[type[Any], Any]) -> None: |
462 | 474 | args = type_args(cls) |
463 | 475 | if args: |
464 | 476 | recursive(args[0]) |
465 | | - elif is_list(cls) or is_set(cls) or is_deque(cls): |
| 477 | + elif is_list(cls) or is_set(cls) or is_deque(cls) or is_counter(cls): |
466 | 478 | args = type_args(cls) |
467 | 479 | if args: |
468 | 480 | recursive(args[0]) |
@@ -830,6 +842,40 @@ def is_bare_deque(typ: type[Any]) -> bool: |
830 | 842 | return typ is deque |
831 | 843 |
|
832 | 844 |
|
| 845 | +@cache |
| 846 | +def is_counter(typ: type[Any]) -> bool: |
| 847 | + """ |
| 848 | + Test if the type is `collections.Counter`. |
| 849 | +
|
| 850 | + >>> is_counter(Counter[str]) |
| 851 | + True |
| 852 | + >>> is_counter(Counter) |
| 853 | + True |
| 854 | + >>> is_counter(dict[str, int]) |
| 855 | + False |
| 856 | + """ |
| 857 | + try: |
| 858 | + return issubclass(get_origin(typ), Counter) # type: ignore |
| 859 | + except TypeError: |
| 860 | + return typ is Counter |
| 861 | + |
| 862 | + |
| 863 | +@cache |
| 864 | +def is_bare_counter(typ: type[Any]) -> bool: |
| 865 | + """ |
| 866 | + Test if the type is `collections.Counter` without type args. |
| 867 | +
|
| 868 | + >>> is_bare_counter(Counter[str]) |
| 869 | + False |
| 870 | + >>> is_bare_counter(Counter) |
| 871 | + True |
| 872 | + """ |
| 873 | + origin = get_origin(typ) |
| 874 | + if origin is Counter: |
| 875 | + return not type_args(typ) |
| 876 | + return typ is Counter |
| 877 | + |
| 878 | + |
833 | 879 | @cache |
834 | 880 | def is_none(typ: type[Any]) -> bool: |
835 | 881 | """ |
|
0 commit comments