|
29 | 29 | import inspect |
30 | 30 | from collections.abc import Callable, Coroutine, Generator |
31 | 31 | from types import MappingProxyType, UnionType |
32 | | -from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Union, Unpack |
| 32 | +from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Union, Unpack, overload |
33 | 33 |
|
34 | 34 | from twitchio.utils import MISSING |
35 | 35 |
|
|
67 | 67 | Coro: TypeAlias = Coroutine[Any, Any, None] |
68 | 68 | CoroC: TypeAlias = Coroutine[Any, Any, bool] |
69 | 69 |
|
| 70 | +DT = TypeVar("DT") |
| 71 | +VT = TypeVar("VT") |
| 72 | + |
70 | 73 |
|
71 | 74 | class CommandErrorPayload: |
72 | 75 | """Payload received in the :func:`~twitchio.event_command_error` event. |
@@ -395,9 +398,17 @@ def after_invoke(self) -> None: ... |
395 | 398 |
|
396 | 399 | class Mixin(Generic[Component_T]): |
397 | 400 | def __init__(self, *args: Any, **kwargs: Any) -> None: |
398 | | - self._commands: dict[str, Command[Component_T, ...]] = {} |
| 401 | + case_: bool = kwargs.pop("case_insensitive", False) |
| 402 | + self._case_insensitive: bool = case_ |
| 403 | + self._commands: dict[str, Command[Component_T, ...]] = {} if not case_ else _CaseInsensitiveDict() |
| 404 | + |
399 | 405 | super().__init__(*args, **kwargs) |
400 | 406 |
|
| 407 | + @property |
| 408 | + def case_insensitive(self) -> bool: |
| 409 | + """Property returning a bool indicating whether this Mixin is using case insensitive commands.""" |
| 410 | + return self._case_insensitive |
| 411 | + |
401 | 412 | def add_command(self, command: Command[Component_T, ...], /) -> None: |
402 | 413 | """Add a :class:`~.commands.Command` object to the mixin. |
403 | 414 |
|
@@ -986,3 +997,44 @@ def predicate(context: Context) -> bool: |
986 | 997 | return chatter.moderator or chatter.vip |
987 | 998 |
|
988 | 999 | return guard(predicate) |
| 1000 | + |
| 1001 | + |
| 1002 | +class _CaseInsensitiveDict(dict[str, VT]): |
| 1003 | + def __contains__(self, key: object) -> bool: |
| 1004 | + return super().__contains__(key.casefold()) if isinstance(key, str) else False |
| 1005 | + |
| 1006 | + def __delitem__(self, key: str) -> None: |
| 1007 | + return super().__delitem__(key.casefold()) |
| 1008 | + |
| 1009 | + def __getitem__(self, key: str) -> VT: |
| 1010 | + return super().__getitem__(key.casefold()) |
| 1011 | + |
| 1012 | + @overload |
| 1013 | + def get(self, key: str, /) -> VT | None: ... |
| 1014 | + |
| 1015 | + @overload |
| 1016 | + def get(self, key: str, default: VT, /) -> VT: ... |
| 1017 | + |
| 1018 | + @overload |
| 1019 | + def get(self, key: str, default: DT, /) -> VT | DT: ... |
| 1020 | + |
| 1021 | + def get(self, key: str, default: DT = None) -> VT | DT: |
| 1022 | + return super().get(key.casefold(), default) |
| 1023 | + |
| 1024 | + @overload |
| 1025 | + def pop(self, key: str, /) -> VT: ... |
| 1026 | + |
| 1027 | + @overload |
| 1028 | + def pop(self, key: str, default: VT, /) -> VT: ... |
| 1029 | + |
| 1030 | + @overload |
| 1031 | + def pop(self, key: str, default: DT, /) -> VT | DT: ... |
| 1032 | + |
| 1033 | + def pop(self, key: str, default: DT = MISSING) -> VT | DT: |
| 1034 | + if default is MISSING: |
| 1035 | + return super().pop(key.casefold()) |
| 1036 | + |
| 1037 | + return super().pop(key.casefold(), default) |
| 1038 | + |
| 1039 | + def __setitem__(self, key: str, value: VT) -> None: |
| 1040 | + super().__setitem__(key.casefold(), value) |
0 commit comments