Skip to content

Commit 01ced9d

Browse files
committed
Start of commands. argument parsing.
1 parent fdf1d02 commit 01ced9d

File tree

6 files changed

+254
-4
lines changed

6 files changed

+254
-4
lines changed

twitchio/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,5 @@
3535
from .exceptions import *
3636
from .models import *
3737
from .payloads import *
38+
from .user import *
3839
from .utils import Color as Color, Colour as Colour

twitchio/ext/commands/bot.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from twitchio.client import Client
3131

3232
from .context import Context
33+
from .converters import _BaseConverter
3334
from .core import Command, CommandErrorPayload, Group, Mixin
3435
from .exceptions import *
3536

@@ -65,6 +66,7 @@ def __init__(
6566

6667
self._get_prefix: Prefix_T = prefix
6768
self._components: dict[str, Component] = {}
69+
self._base_converter: _BaseConverter = _BaseConverter(self)
6870

6971
@property
7072
def bot_id(self) -> str:
@@ -135,3 +137,5 @@ async def event_command_error(self, payload: CommandErrorPayload) -> None:
135137
async def before_invoke(self, ctx: Context) -> None: ...
136138

137139
async def after_invoke(self, ctx: Context) -> None: ...
140+
141+
async def check(self, ctx: Context) -> None: ...

twitchio/ext/commands/context.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def __init__(self, message: ChatMessage, bot: Bot) -> None:
5959

6060
self._view: StringView = StringView(self._raw_content)
6161

62+
self._args: list[Any] = []
63+
self._kwargs: dict[str, Any] = {}
64+
6265
@property
6366
def message(self) -> ChatMessage:
6467
return self._message
@@ -111,6 +114,14 @@ def error_dispatched(self) -> bool:
111114
def error_dispatched(self, value: bool, /) -> None:
112115
self._error_dispatched = value
113116

117+
@property
118+
def args(self) -> list[Any]:
119+
return self._args
120+
121+
@property
122+
def kwargs(self) -> dict[str, Any]:
123+
return self._kwargs
124+
114125
def is_valid(self) -> bool:
115126
return self._prefix is not None
116127

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
"""
2+
MIT License
3+
4+
Copyright (c) 2017 - Present PythonistaGuild
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
from typing import TYPE_CHECKING, Any
28+
29+
from twitchio.user import User
30+
31+
from .exceptions import *
32+
33+
34+
if TYPE_CHECKING:
35+
from .bot import Bot
36+
from .context import Context
37+
38+
__all__ = ("_BaseConverter",)
39+
40+
41+
class _BaseConverter:
42+
def __init__(self, client: Bot) -> None:
43+
self.__client: Bot = client
44+
45+
self._MAPPING: dict[Any, Any] = {User: self._user}
46+
self._DEFAULTS: dict[type, type] = {str: str, int: int, float: float}
47+
48+
async def _user(self, context: Context, arg: str) -> User:
49+
arg = arg.lower()
50+
users: list[User]
51+
msg: str = 'Failed to convert "{}" to User. A User with the ID or login could not be found.'
52+
53+
if arg.startswith("@"):
54+
arg = arg.removeprefix("@")
55+
users = await self.__client.fetch_users(logins=[arg])
56+
57+
if not users:
58+
raise BadArgument(msg.format(arg), value=arg)
59+
60+
if arg.isdigit():
61+
users = await self.__client.fetch_users(logins=[arg], ids=[arg])
62+
else:
63+
users = await self.__client.fetch_users(logins=[arg])
64+
65+
potential: list[User] = []
66+
67+
for user in users:
68+
# ID's should be taken into consideration first...
69+
if user.id == arg:
70+
return user
71+
72+
elif user.name == arg:
73+
potential.append(user)
74+
75+
if potential:
76+
return potential[0]
77+
78+
raise BadArgument(msg.format(arg), value=arg)

twitchio/ext/commands/core.py

Lines changed: 135 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@
2525
from __future__ import annotations
2626

2727
import asyncio
28+
import inspect
2829
from collections.abc import Callable, Coroutine
29-
from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Unpack
30+
from types import UnionType
31+
from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Union, Unpack
32+
33+
from twitchio.utils import MISSING
3034

3135
from .exceptions import *
3236
from .types_ import CommandOptions, Component_T
@@ -112,10 +116,134 @@ def extras(self) -> dict[Any, Any]:
112116
def has_error(self) -> bool:
113117
return self._error is not None
114118

119+
async def _do_conversion(self, context: Context, param: inspect.Parameter, *, annotation: Any, raw: str | None) -> Any:
120+
name: str = param.name
121+
122+
if isinstance(annotation, UnionType) or getattr(annotation, "__origin__", None) is Union:
123+
converters = list(annotation.__args__)
124+
converters.remove(type(None))
125+
126+
result: Any = MISSING
127+
128+
for c in converters:
129+
try:
130+
result = await self._do_conversion(context, param=param, annotation=c, raw=raw)
131+
except Exception:
132+
continue
133+
134+
if result is MISSING:
135+
raise BadArgument(
136+
f'Failed to convert argument "{name}" with any converter from Union: {converters}. No default value was provided.',
137+
name=name,
138+
value=raw,
139+
)
140+
141+
return result
142+
143+
base = context.bot._base_converter._DEFAULTS.get(annotation, None if annotation != param.empty else str)
144+
if base:
145+
try:
146+
result = base(raw)
147+
except Exception as e:
148+
raise BadArgument(f'Failed to convert "{name}" to {base}', name=name, value=raw) from e
149+
150+
return result
151+
152+
converter = context.bot._base_converter._MAPPING.get(annotation, annotation)
153+
154+
try:
155+
result = converter(context, raw)
156+
except Exception as e:
157+
raise BadArgument(f'Failed to convert "{name}" to {type(converter)}', name=name, value=raw) from e
158+
159+
if not asyncio.iscoroutine(result):
160+
return result
161+
162+
try:
163+
result = await result
164+
except Exception as e:
165+
raise BadArgument(f'Failed to convert "{name}" to {type(converter)}', name=name, value=raw) from e
166+
167+
return result
168+
169+
async def _parse_arguments(self, context: Context) -> ...:
170+
context._view.skip_ws()
171+
signature: inspect.Signature = inspect.signature(self._callback)
172+
173+
# We expect context always and self with commands in components...
174+
skip: int = 1 if not self._injected else 2
175+
params: list[inspect.Parameter] = list(signature.parameters.values())[skip:]
176+
177+
args: list[Any] = []
178+
kwargs = {}
179+
180+
for param in params:
181+
if param.kind == param.KEYWORD_ONLY:
182+
raw = context._view.read_rest()
183+
184+
if not raw:
185+
if param.default == param.empty:
186+
raise MissingRequiredArgument(param=param)
187+
188+
kwargs[param.name] = param.default
189+
continue
190+
191+
result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation)
192+
kwargs[param.name] = result
193+
break
194+
195+
elif param.kind == param.VAR_POSITIONAL:
196+
packed: list[Any] = []
197+
198+
while True:
199+
context._view.skip_ws()
200+
raw = context._view.get_quoted_word()
201+
if not raw:
202+
break
203+
204+
result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation)
205+
packed.append(result)
206+
207+
args.extend(packed)
208+
break
209+
210+
elif param.kind == param.POSITIONAL_OR_KEYWORD:
211+
raw = context._view.get_quoted_word()
212+
context._view.skip_ws()
213+
214+
if not raw:
215+
if param.default == param.empty:
216+
raise MissingRequiredArgument(param=param)
217+
218+
args.append(param.default)
219+
continue
220+
221+
result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation)
222+
args.append(result)
223+
224+
return args, kwargs
225+
226+
async def _do_checks(self, context: Context) -> ...:
227+
# Bot
228+
# Component
229+
# Command
230+
...
231+
115232
async def _invoke(self, context: Context) -> None:
116-
# TODO: Argument parsing...
117-
# TODO: Checks... Including cooldowns...
118-
callback = self._callback(self._injected, context) if self._injected else self._callback(context) # type: ignore
233+
try:
234+
args, kwargs = await self._parse_arguments(context)
235+
except (ConversionError, MissingRequiredArgument):
236+
raise
237+
except Exception as e:
238+
raise ConversionError("An unknown error occurred converting arguments.") from e
239+
240+
context._args = args
241+
context._kwargs = kwargs
242+
243+
args: list[Any] = [context, *args]
244+
args.insert(0, self._injected) if self._injected else None
245+
246+
callback = self._callback(*args, **kwargs) # type: ignore
119247

120248
try:
121249
await callback
@@ -127,6 +255,9 @@ async def invoke(self, context: Context) -> None:
127255
await self._invoke(context)
128256
except CommandError as e:
129257
await self._dispatch_error(context, e)
258+
except Exception as e:
259+
error = CommandInvokeError(str(e), original=e)
260+
await self._dispatch_error(context, error)
130261

131262
async def _dispatch_error(self, context: Context, exception: CommandError) -> None:
132263
payload = CommandErrorPayload(context=context, exception=exception)

twitchio/ext/commands/exceptions.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
SOFTWARE.
2323
"""
2424

25+
import inspect
26+
2527
from twitchio.exceptions import TwitchioException
2628

2729

@@ -35,6 +37,10 @@
3537
"PrefixError",
3638
"InputError",
3739
"ArgumentError",
40+
"CheckFailure",
41+
"ConversionError",
42+
"BadArgument",
43+
"MissingRequiredArgument",
3844
)
3945

4046

@@ -84,3 +90,22 @@ class ExpectedClosingQuoteError(ArgumentError):
8490
def __init__(self, close_quote: str) -> None:
8591
self.close_quote: str = close_quote
8692
super().__init__(f"Expected closing {close_quote}.")
93+
94+
95+
class CheckFailure(CommandError): ...
96+
97+
98+
class ConversionError(ArgumentError): ...
99+
100+
101+
class BadArgument(ConversionError):
102+
def __init__(self, msg: str, *, name: str | None = None, value: str | None) -> None:
103+
self.name: str | None = name
104+
self.value: str | None = value
105+
super().__init__(msg)
106+
107+
108+
class MissingRequiredArgument(ArgumentError):
109+
def __init__(self, param: inspect.Parameter) -> None:
110+
self.param: inspect.Parameter = param
111+
super().__init__(f'"{param.name}" is a required argument which is missing.')

0 commit comments

Comments
 (0)