Skip to content

Commit 20d521a

Browse files
author
Eviee Py
committed
Base guard implementation
1 parent 35e73ec commit 20d521a

File tree

2 files changed

+67
-29
lines changed

2 files changed

+67
-29
lines changed

twitchio/ext/commands/bot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,4 +138,4 @@ async def before_invoke(self, ctx: Context) -> None: ...
138138

139139
async def after_invoke(self, ctx: Context) -> None: ...
140140

141-
async def check(self, ctx: Context) -> None: ...
141+
async def guard(self, ctx: Context) -> None: ...

twitchio/ext/commands/core.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -36,14 +36,7 @@
3636
from .types_ import CommandOptions, Component_T
3737

3838

39-
__all__ = (
40-
"CommandErrorPayload",
41-
"Command",
42-
"Mixin",
43-
"Group",
44-
"command",
45-
"group",
46-
)
39+
__all__ = ("CommandErrorPayload", "Command", "Mixin", "Group", "command", "group", "is_broadcaster")
4740

4841

4942
if TYPE_CHECKING:
@@ -55,6 +48,7 @@
5548

5649

5750
Coro: TypeAlias = Coroutine[Any, Any, None]
51+
CoroC: TypeAlias = Coroutine[Any, Any, bool]
5852

5953

6054
class CommandErrorPayload:
@@ -76,6 +70,7 @@ def __init__(
7670
self._name: str = name
7771
self._callback = callback
7872
self._aliases: list[str] = kwargs.get("aliases", [])
73+
self._guards: list[Callable[..., bool] | Callable[..., CoroC]] = getattr(self._callback, "__command_guards__", [])
7974

8075
self._injected: Component_T | None = None
8176
self._error: Callable[[Component_T, CommandErrorPayload], Coro] | Callable[[CommandErrorPayload], Coro] | None = None
@@ -116,6 +111,13 @@ def extras(self) -> dict[Any, Any]:
116111
def has_error(self) -> bool:
117112
return self._error is not None
118113

114+
@property
115+
def guards(self) -> ...: ...
116+
117+
@property
118+
def callback(self) -> Callable[Concatenate[Component_T, Context, P], Coro] | Callable[Concatenate[Context, P], Coro]:
119+
return self._callback
120+
119121
async def _do_conversion(self, context: Context, param: inspect.Parameter, *, annotation: Any, raw: str | None) -> Any:
120122
name: str = param.name
121123

@@ -181,16 +183,15 @@ async def _parse_arguments(self, context: Context) -> ...:
181183
if param.kind == param.KEYWORD_ONLY:
182184
raw = context._view.read_rest()
183185

184-
if not raw:
185-
if param.default == param.empty:
186-
raise MissingRequiredArgument(param=param)
186+
if raw:
187+
result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation)
188+
kwargs[param.name] = result
189+
break
187190

188-
kwargs[param.name] = param.default
189-
continue
191+
if param.default == param.empty:
192+
raise MissingRequiredArgument(param=param)
190193

191-
result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation)
192-
kwargs[param.name] = result
193-
break
194+
kwargs[param.name] = param.default
194195

195196
elif param.kind == param.VAR_POSITIONAL:
196197
packed: list[Any] = []
@@ -211,23 +212,25 @@ async def _parse_arguments(self, context: Context) -> ...:
211212
raw = context._view.get_quoted_word()
212213
context._view.skip_ws()
213214

214-
if not raw:
215-
if param.default == param.empty:
216-
raise MissingRequiredArgument(param=param)
217-
218-
args.append(param.default)
215+
if raw:
216+
result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation)
217+
args.append(result)
219218
continue
220219

221-
result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation)
222-
args.append(result)
220+
if param.default == param.empty:
221+
raise MissingRequiredArgument(param=param)
222+
223+
args.append(param.default)
223224

224225
return args, kwargs
225226

226-
async def _do_checks(self, context: Context) -> ...:
227-
# Bot
228-
# Component
229-
# Command
230-
...
227+
async def _run_guards(self, context: Context) -> ...:
228+
# TODO ...
229+
for guard in self._guards:
230+
result = guard(context)
231+
232+
if not result:
233+
raise CheckFailure
231234

232235
async def _invoke(self, context: Context) -> None:
233236
try:
@@ -243,6 +246,8 @@ async def _invoke(self, context: Context) -> None:
243246
args: list[Any] = [context, *args]
244247
args.insert(0, self._injected) if self._injected else None
245248

249+
await self._run_guards(context)
250+
246251
callback = self._callback(*args, **kwargs) # type: ignore
247252

248253
try:
@@ -284,6 +289,16 @@ def error(
284289
self._error = func
285290
return func
286291

292+
def add_guard(self) -> None: ...
293+
294+
def remove_guard(
295+
self,
296+
) -> None: ...
297+
298+
def before_invoke(self) -> None: ...
299+
300+
def after_invoke(self) -> None: ...
301+
287302

288303
class Mixin(Generic[Component_T]):
289304
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -427,3 +442,26 @@ def wrapper(
427442
return new
428443

429444
return wrapper
445+
446+
447+
def guard(predicate: Callable[..., bool] | Callable[..., CoroC]) -> Any:
448+
def wrapper(func: Any) -> Any:
449+
if isinstance(func, Command):
450+
func._guards.append(predicate)
451+
452+
else:
453+
try:
454+
func.__command_guards__.append(predicate)
455+
except AttributeError:
456+
func.__command_guards__ = [predicate]
457+
458+
return func # type: ignore
459+
460+
return wrapper
461+
462+
463+
def is_broadcaster() -> Any:
464+
def predicate(context: Context) -> bool:
465+
return context.chatter.id == context.broadcaster.id
466+
467+
return guard(predicate)

0 commit comments

Comments
 (0)