3636from .types_ import CommandOptions , Component_T
3737
3838
39- __all__ = (
40- "CommandErrorPayload" ,
41- "Command" ,
42- "Mixin" ,
43- "Group" ,
44- "command" ,
45- "group" ,
46- )
39+ __all__ = ("CommandErrorPayload" , "Command" , "Mixin" , "Group" , "command" , "group" , "is_broadcaster" )
4740
4841
4942if TYPE_CHECKING :
5548
5649
5750Coro : TypeAlias = Coroutine [Any , Any , None ]
51+ CoroC : TypeAlias = Coroutine [Any , Any , bool ]
5852
5953
6054class CommandErrorPayload :
@@ -76,6 +70,7 @@ def __init__(
7670 self ._name : str = name
7771 self ._callback = callback
7872 self ._aliases : list [str ] = kwargs .get ("aliases" , [])
73+ self ._guards : list [Callable [..., bool ] | Callable [..., CoroC ]] = getattr (self ._callback , "__command_guards__" , [])
7974
8075 self ._injected : Component_T | None = None
8176 self ._error : Callable [[Component_T , CommandErrorPayload ], Coro ] | Callable [[CommandErrorPayload ], Coro ] | None = None
@@ -116,6 +111,13 @@ def extras(self) -> dict[Any, Any]:
116111 def has_error (self ) -> bool :
117112 return self ._error is not None
118113
114+ @property
115+ def guards (self ) -> ...: ...
116+
117+ @property
118+ def callback (self ) -> Callable [Concatenate [Component_T , Context , P ], Coro ] | Callable [Concatenate [Context , P ], Coro ]:
119+ return self ._callback
120+
119121 async def _do_conversion (self , context : Context , param : inspect .Parameter , * , annotation : Any , raw : str | None ) -> Any :
120122 name : str = param .name
121123
@@ -181,16 +183,15 @@ async def _parse_arguments(self, context: Context) -> ...:
181183 if param .kind == param .KEYWORD_ONLY :
182184 raw = context ._view .read_rest ()
183185
184- if not raw :
185- if param .default == param .empty :
186- raise MissingRequiredArgument (param = param )
186+ if raw :
187+ result = await self ._do_conversion (context , param = param , raw = raw , annotation = param .annotation )
188+ kwargs [param .name ] = result
189+ break
187190
188- kwargs [ param .name ] = param .default
189- continue
191+ if param .default == param .empty :
192+ raise MissingRequiredArgument ( param = param )
190193
191- result = await self ._do_conversion (context , param = param , raw = raw , annotation = param .annotation )
192- kwargs [param .name ] = result
193- break
194+ kwargs [param .name ] = param .default
194195
195196 elif param .kind == param .VAR_POSITIONAL :
196197 packed : list [Any ] = []
@@ -211,23 +212,25 @@ async def _parse_arguments(self, context: Context) -> ...:
211212 raw = context ._view .get_quoted_word ()
212213 context ._view .skip_ws ()
213214
214- if not raw :
215- if param .default == param .empty :
216- raise MissingRequiredArgument (param = param )
217-
218- args .append (param .default )
215+ if raw :
216+ result = await self ._do_conversion (context , param = param , raw = raw , annotation = param .annotation )
217+ args .append (result )
219218 continue
220219
221- result = await self ._do_conversion (context , param = param , raw = raw , annotation = param .annotation )
222- args .append (result )
220+ if param .default == param .empty :
221+ raise MissingRequiredArgument (param = param )
222+
223+ args .append (param .default )
223224
224225 return args , kwargs
225226
226- async def _do_checks (self , context : Context ) -> ...:
227- # Bot
228- # Component
229- # Command
230- ...
227+ async def _run_guards (self , context : Context ) -> ...:
228+ # TODO ...
229+ for guard in self ._guards :
230+ result = guard (context )
231+
232+ if not result :
233+ raise CheckFailure
231234
232235 async def _invoke (self , context : Context ) -> None :
233236 try :
@@ -243,6 +246,8 @@ async def _invoke(self, context: Context) -> None:
243246 args : list [Any ] = [context , * args ]
244247 args .insert (0 , self ._injected ) if self ._injected else None
245248
249+ await self ._run_guards (context )
250+
246251 callback = self ._callback (* args , ** kwargs ) # type: ignore
247252
248253 try :
@@ -284,6 +289,16 @@ def error(
284289 self ._error = func
285290 return func
286291
292+ def add_guard (self ) -> None : ...
293+
294+ def remove_guard (
295+ self ,
296+ ) -> None : ...
297+
298+ def before_invoke (self ) -> None : ...
299+
300+ def after_invoke (self ) -> None : ...
301+
287302
288303class Mixin (Generic [Component_T ]):
289304 def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
@@ -427,3 +442,26 @@ def wrapper(
427442 return new
428443
429444 return wrapper
445+
446+
447+ def guard (predicate : Callable [..., bool ] | Callable [..., CoroC ]) -> Any :
448+ def wrapper (func : Any ) -> Any :
449+ if isinstance (func , Command ):
450+ func ._guards .append (predicate )
451+
452+ else :
453+ try :
454+ func .__command_guards__ .append (predicate )
455+ except AttributeError :
456+ func .__command_guards__ = [predicate ]
457+
458+ return func # type: ignore
459+
460+ return wrapper
461+
462+
463+ def is_broadcaster () -> Any :
464+ def predicate (context : Context ) -> bool :
465+ return context .chatter .id == context .broadcaster .id
466+
467+ return guard (predicate )
0 commit comments