Skip to content

Commit 761974e

Browse files
committed
Add cusomt command argument converters
1 parent eeaa7c2 commit 761974e

File tree

4 files changed

+152
-28
lines changed

4 files changed

+152
-28
lines changed

twitchio/ext/commands/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,4 @@
2828
from .cooldowns import *
2929
from .core import *
3030
from .exceptions import *
31+
from .converters import *

twitchio/ext/commands/bot.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636

3737
from ...utils import _is_submodule
3838
from .context import Context
39-
from .converters import _BaseConverter
4039
from .core import Command, CommandErrorPayload, Group, Mixin
4140
from .exceptions import *
4241

@@ -176,7 +175,6 @@ def __init__(
176175
self._owner_id: str | None = owner_id
177176
self._get_prefix: PrefixT = prefix
178177
self._components: dict[str, Component] = {}
179-
self._base_converter: _BaseConverter = _BaseConverter(self)
180178
self.__modules: dict[str, types.ModuleType] = {}
181179
self._owner: User | None = None
182180

twitchio/ext/commands/converters.py

Lines changed: 133 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -24,19 +24,21 @@
2424

2525
from __future__ import annotations
2626

27-
from typing import TYPE_CHECKING, Any
27+
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, runtime_checkable
2828

29+
from twitchio.ext.commands.context import Context
2930
from twitchio.user import User
31+
from twitchio.utils import Colour
3032

3133
from .exceptions import *
3234

3335

3436
if TYPE_CHECKING:
35-
from .bot import Bot
3637
from .context import Context
37-
from .types_ import BotT
3838

39-
__all__ = ("_BaseConverter",)
39+
40+
__all__ = ("Converter", "UserConverter")
41+
4042

4143
_BOOL_MAPPING: dict[str, bool] = {
4244
"true": True,
@@ -52,38 +54,86 @@
5254
}
5355

5456

55-
class _BaseConverter:
56-
def __init__(self, client: Bot) -> None:
57-
self.__client: Bot = client
57+
T_co = TypeVar("T_co", covariant=True)
5858

59-
self._MAPPING: dict[Any, Any] = {User: self._user}
60-
self._DEFAULTS: dict[type, Any] = {str: str, int: int, float: float, bool: self._bool, type(None): type(None)}
6159

62-
def _bool(self, arg: str) -> bool:
63-
try:
64-
result = _BOOL_MAPPING[arg.lower()]
65-
except KeyError:
66-
pretty: str = " | ".join(f'"{k}"' for k in _BOOL_MAPPING)
67-
raise BadArgument(f'Failed to convert "{arg}" to type bool. Expected any: [{pretty}]', value=arg)
60+
@runtime_checkable
61+
class Converter(Protocol[T_co]):
62+
"""Base class used to create custom argument converters in :class:`~twitchio.ext.commands.Command`'s.
6863
69-
return result
64+
To create a custom converter and do conversion logic on an argument you must override the :meth:`.convert` method.
65+
:meth:`.convert` must be a coroutine.
66+
67+
Examples
68+
--------
69+
70+
.. code:: python3
71+
72+
class LowerCaseConverter(commands.Converter[str]):
73+
74+
async def convert(self, ctx: commands.Context, arg: str) -> str:
75+
return arg.lower()
76+
77+
78+
@commands.command()
79+
async def test(ctx: commands.Context, arg: LowerCaseConverter) -> None: ...
80+
81+
82+
.. versionadded:: 3.1
83+
"""
84+
85+
async def convert(self, ctx: Context[Any], arg: str) -> T_co:
86+
"""|coro|
87+
88+
Method used on converters to implement conversion logic.
89+
90+
Parameters
91+
----------
92+
ctx: :class:`~twitchio.ext.commands.Context`
93+
The context provided to the converter after command invocation has started.
94+
arg: str
95+
The argument received in raw form as a :class:`str` and passed to the converter to do conversion logic on.
96+
"""
97+
raise NotImplementedError("Classes that derive from Converter must implement this method.")
98+
99+
100+
class UserConverter(Converter[User]):
101+
"""The converter used to convert command arguments to a :class:`twitchio.User`.
102+
103+
This is a default converter which can be used in commands by annotating arguments with the :class:`twitchio.User` type.
104+
105+
.. note::
106+
107+
This converter uses an API call to attempt to fetch a valid :class:`twitchio.User`.
108+
109+
110+
Example
111+
-------
112+
113+
.. code:: python3
114+
115+
@commands.command()
116+
async def test(ctx: commands.Context, *, user: twitchio.User) -> None: ...
117+
"""
118+
119+
async def convert(self, ctx: Context[Any], arg: str) -> User:
120+
client = ctx.bot
70121

71-
async def _user(self, context: Context[BotT], arg: str) -> User:
72122
arg = arg.lower()
73123
users: list[User]
74124
msg: str = 'Failed to convert "{}" to User. A User with the ID or login could not be found.'
75125

76126
if arg.startswith("@"):
77127
arg = arg.removeprefix("@")
78-
users = await self.__client.fetch_users(logins=[arg])
128+
users = await client.fetch_users(logins=[arg])
79129

80130
if not users:
81131
raise BadArgument(msg.format(arg), value=arg)
82132

83133
if arg.isdigit():
84-
users = await self.__client.fetch_users(logins=[arg], ids=[arg])
134+
users = await client.fetch_users(logins=[arg], ids=[arg])
85135
else:
86-
users = await self.__client.fetch_users(logins=[arg])
136+
users = await client.fetch_users(logins=[arg])
87137

88138
potential: list[User] = []
89139

@@ -99,3 +149,66 @@ async def _user(self, context: Context[BotT], arg: str) -> User:
99149
return potential[0]
100150

101151
raise BadArgument(msg.format(arg), value=arg)
152+
153+
154+
class ColourConverter(Converter[Colour]):
155+
"""The converter used to convert command arguments to a :class:`~twitchio.utils.Colour` object.
156+
157+
This is a default converter which can be used in commands by annotating arguments with the :class:`twitchio.utils.Colour` type.
158+
159+
This converter, attempts to convert ``hex`` and ``int`` type values only in the following formats:
160+
161+
- `"#FFDD00"`
162+
- `"FFDD00"`
163+
- `"0xFFDD00"`
164+
- `16768256`
165+
166+
167+
``hex`` values are attempted first, followed by ``int``.
168+
169+
.. note::
170+
171+
There is an alias to this converter named ``ColorConverter``.
172+
173+
Example
174+
-------
175+
176+
.. code:: python3
177+
178+
@commands.command()
179+
async def test(ctx: commands.Context, *, colour: twitchio.utils.Colour) -> None: ...
180+
181+
.. versionadded:: 3.1
182+
"""
183+
184+
async def convert(self, ctx: Context[Any], arg: str) -> Colour:
185+
try:
186+
result = Colour.from_hex(arg)
187+
except Exception:
188+
pass
189+
else:
190+
return result
191+
192+
try:
193+
result = Colour.from_int(int(arg))
194+
except Exception:
195+
raise ConversionError(f"Unable to convert to Colour. {arg!r} is not a valid hex or colour integer value.")
196+
197+
return result
198+
199+
200+
ColorConverter = ColourConverter
201+
202+
203+
def _bool(arg: str) -> bool:
204+
try:
205+
result = _BOOL_MAPPING[arg.lower()]
206+
except KeyError:
207+
pretty: str = " | ".join(f'"{k}"' for k in _BOOL_MAPPING)
208+
raise BadArgument(f'Failed to convert "{arg}" to type bool. Expected any: [{pretty}]', value=arg)
209+
210+
return result
211+
212+
213+
DEFAULT_CONVERTERS: dict[type, Any] = {str: str, int: int, float: float, bool: _bool, type(None): type(None)}
214+
CONVERTER_MAPPING: dict[Any, Converter[Any] | type[Converter[Any]]] = {User: UserConverter}

twitchio/ext/commands/core.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
from .exceptions import *
4141
from .types_ import CommandOptions, Component_T
4242

43+
from .converters import DEFAULT_CONVERTERS, CONVERTER_MAPPING, Converter
44+
4345

4446
__all__ = (
4547
"Command",
@@ -359,7 +361,7 @@ def _convert_literal_type(
359361

360362
for arg in reversed(args):
361363
type_: type[Any] = type(arg) # type: ignore
362-
if base := context.bot._base_converter._DEFAULTS.get(type_):
364+
if base := DEFAULT_CONVERTERS.get(type_):
363365
try:
364366
result = base(raw)
365367
except Exception:
@@ -377,6 +379,7 @@ async def _do_conversion(
377379
self, context: Context[BotT], param: inspect.Parameter, *, annotation: Any, raw: str | None
378380
) -> Any:
379381
name: str = param.name
382+
result: Any = MISSING
380383

381384
if isinstance(annotation, UnionType) or getattr(annotation, "__origin__", None) is Union:
382385
converters = list(annotation.__args__)
@@ -386,8 +389,6 @@ async def _do_conversion(
386389
except ValueError:
387390
pass
388391

389-
result: Any = MISSING
390-
391392
for c in reversed(converters):
392393
try:
393394
result = await self._do_conversion(context, param=param, annotation=c, raw=raw)
@@ -414,7 +415,7 @@ async def _do_conversion(
414415

415416
return result
416417

417-
base = context.bot._base_converter._DEFAULTS.get(annotation, None if annotation != param.empty else str)
418+
base = DEFAULT_CONVERTERS.get(annotation, None if annotation != param.empty else str)
418419
if base:
419420
try:
420421
result = base(raw)
@@ -423,13 +424,24 @@ async def _do_conversion(
423424

424425
return result
425426

426-
converter = context.bot._base_converter._MAPPING.get(annotation, annotation)
427+
converter = CONVERTER_MAPPING.get(annotation, annotation)
427428

428429
try:
429-
result = converter(context, raw)
430+
if inspect.isclass(converter) and issubclass(converter, Converter): # type: ignore
431+
if inspect.ismethod(converter.convert):
432+
result = converter.convert(context, raw)
433+
else:
434+
result = converter().convert(context, str(raw))
435+
elif isinstance(converter, Converter):
436+
result = converter.convert(context, str(raw))
437+
except CommandError:
438+
raise
430439
except Exception as e:
431440
raise BadArgument(f'Failed to convert "{name}" to {type(converter)}', name=name, value=raw) from e
432441

442+
if result is MISSING:
443+
raise BadArgument(f'Failed to convert "{name}" to {type(converter)}', name=name, value=raw)
444+
433445
if not asyncio.iscoroutine(result):
434446
return result
435447

0 commit comments

Comments
 (0)