Skip to content

Commit 3738f76

Browse files
committed
Add command cooldowns implementation
1 parent 93e7fcb commit 3738f76

File tree

5 files changed

+496
-20
lines changed

5 files changed

+496
-20
lines changed

twitchio/ext/commands/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@
2525
from .bot import Bot as Bot
2626
from .components import *
2727
from .context import *
28+
from .cooldowns import *
2829
from .core import *
2930
from .exceptions import *

twitchio/ext/commands/cooldowns.py

Lines changed: 280 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,280 @@
1+
"""
2+
MIT License
3+
4+
Copyright (c) 2017 - Present PythonistaGuild
5+
6+
Permission is hereby granted, free of charge, to any person obtaining a copy
7+
of this software and associated documentation files (the "Software"), to deal
8+
in the Software without restriction, including without limitation the rights
9+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
copies of the Software, and to permit persons to whom the Software is
11+
furnished to do so, subject to the following conditions:
12+
13+
The above copyright notice and this permission notice shall be included in all
14+
copies or substantial portions of the Software.
15+
16+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22+
SOFTWARE.
23+
"""
24+
25+
from __future__ import annotations
26+
27+
import abc
28+
import asyncio
29+
import datetime
30+
import enum
31+
from collections.abc import Callable, Coroutine, Hashable
32+
from typing import TYPE_CHECKING, Any, Generic, Self, TypeAlias, TypeVar
33+
34+
35+
if TYPE_CHECKING:
36+
import twitchio
37+
38+
from .context import Context
39+
40+
41+
__all__ = ("BaseCooldown", "Cooldown", "GCRACooldown", "BucketType", "Bucket")
42+
43+
44+
PT = TypeVar("PT")
45+
CT = TypeVar("CT")
46+
47+
48+
class BucketType(enum.Enum):
49+
"""Enum representing default implementations for the key argument in :func:`~.commands.cooldown`.
50+
51+
Attributes
52+
----------
53+
default
54+
The cooldown will be considered a global cooldown shared across every channel and user.
55+
user
56+
The cooldown will apply per user, accross all channels.
57+
channel
58+
The cooldown will apply to every user/chatter in the channel.
59+
chatter
60+
The cooldown will apply per user, per channel.
61+
"""
62+
63+
default = 0
64+
user = 1
65+
channel = 2
66+
chatter = 3
67+
68+
def get_key(self, payload: twitchio.ChatMessage | Context) -> Any:
69+
if self is BucketType.user:
70+
return payload.chatter.id
71+
72+
elif self is BucketType.channel:
73+
return ("channel", payload.broadcaster.id)
74+
75+
elif self is BucketType.chatter:
76+
return (payload.broadcaster.id, payload.chatter.id)
77+
78+
def __call__(self, payload: twitchio.ChatMessage | Context) -> Any:
79+
return self.get_key(payload)
80+
81+
82+
class BaseCooldown(abc.ABC):
83+
"""Base class used to implement your own cooldown algorithm for use with :func:`~.commands.cooldown`.
84+
85+
Some built-in cooldown algorithms already exist:
86+
87+
- :class:`~.commands.Cooldown` - (``Token Bucket Algorithm``)
88+
89+
- :class:`~.commands.GCRACooldown` - (``Generic Cell Rate Algorithm``)
90+
91+
92+
.. note::
93+
94+
Every base method must be implemented in this base class.
95+
"""
96+
97+
@abc.abstractmethod
98+
def reset(self) -> None:
99+
"""Base method which should be implemented to reset the cooldown."""
100+
raise NotImplementedError
101+
102+
@abc.abstractmethod
103+
def update(self) -> float | None:
104+
"""Base method which should be implemented to update the cooldown/ratelimit.
105+
106+
This is where your algorithm logic should be contained.
107+
108+
.. important::
109+
110+
This method should always return a :class:`float` or ``None``. If ``None`` is returned by this method,
111+
the cooldown will be considered bypassed.
112+
113+
Returns
114+
-------
115+
:class:`float`
116+
The time needed to wait before you are off cooldown.
117+
``None``
118+
Bypasses the cooldown.
119+
"""
120+
raise NotImplementedError
121+
122+
@abc.abstractmethod
123+
def copy(self) -> Self:
124+
"""Base method which should be implemented to return a copy of this class in it's original state."""
125+
raise NotImplementedError
126+
127+
@abc.abstractmethod
128+
def is_ratelimited(self) -> bool:
129+
"""Base method which should be implemented which returns a bool indicating whether the cooldown is ratelimited.
130+
131+
Returns
132+
-------
133+
bool
134+
A bool indicating whether this cooldown is currently ratelimited.
135+
"""
136+
raise NotImplementedError
137+
138+
@abc.abstractmethod
139+
def is_dead(self) -> bool:
140+
"""Base method which should be implemented to indicate whether the cooldown should be considered stale and allowed
141+
to be removed from the ``bucket: cooldown`` mapping.
142+
143+
Returns
144+
-------
145+
bool
146+
A bool indicating whether this cooldown is stale/old.
147+
"""
148+
raise NotImplementedError
149+
150+
151+
class Cooldown(BaseCooldown):
152+
"""Default cooldown algorithm for :func:`~.commands.cooldown`, which implements a ``Token Bucket Algorithm``.
153+
154+
See: :func:`~.commands.cooldown` for more documentation.
155+
"""
156+
157+
def __init__(self, *, rate: int, per: int | datetime.timedelta) -> None:
158+
if rate <= 0:
159+
raise ValueError(f'Cooldown rate must be equal to or greater than 1. Got "{rate}" expected >= 1.')
160+
161+
self._rate: int = rate
162+
self._per: datetime.timedelta = datetime.timedelta(seconds=per) if isinstance(per, int) else per
163+
164+
now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
165+
self._window: datetime.datetime = now + self._per
166+
167+
if self._window <= now:
168+
raise ValueError("The provided per value for Cooldowns can not go into the past.")
169+
170+
self._tokens: int = self._rate
171+
self.last_updated: datetime.datetime | None = None
172+
173+
@property
174+
def per(self) -> datetime.timedelta:
175+
return self._per
176+
177+
def reset(self) -> None:
178+
self._tokens = self._rate
179+
self._window = datetime.datetime.now(tz=datetime.UTC) + self._per
180+
181+
def get_tokens(self, now: datetime.datetime | None = None) -> int:
182+
if now is None:
183+
now = datetime.datetime.now(tz=datetime.UTC)
184+
185+
tokens = max(self._tokens, 0)
186+
if now > self._window:
187+
tokens = self._rate
188+
189+
return tokens
190+
191+
def is_ratelimited(self) -> bool:
192+
self._tokens = self.get_tokens()
193+
return self._tokens == 0
194+
195+
def update(self, *, factor: int = 1) -> float | None:
196+
now = datetime.datetime.now(tz=datetime.UTC)
197+
self.last_updated = now
198+
199+
self._tokens = self.get_tokens(now)
200+
201+
if self._tokens == self._rate:
202+
self._window = datetime.datetime.now(tz=datetime.UTC) + self._per
203+
204+
self._tokens -= factor
205+
206+
if self._tokens < 0:
207+
remaining = (self._window - now).total_seconds()
208+
return remaining
209+
210+
def copy(self) -> Self:
211+
return self.__class__(rate=self._rate, per=self._per)
212+
213+
def is_dead(self) -> bool:
214+
if self.last_updated is None:
215+
return False
216+
217+
now = datetime.datetime.now(tz=datetime.UTC)
218+
return now > (self.last_updated + self.per)
219+
220+
221+
class GCRACooldown(BaseCooldown):
222+
"""Not implemented yet."""
223+
224+
225+
KeyT: TypeAlias = Callable[[Any], Hashable] | Callable[[Any], Coroutine[Any, Any, Hashable]] | BucketType
226+
227+
228+
class Bucket(Generic[PT]):
229+
def __init__(self, cooldown: BaseCooldown, *, key: KeyT) -> None:
230+
self._cooldown: BaseCooldown = cooldown
231+
self._cache: dict[Hashable, BaseCooldown] = {}
232+
self._key: KeyT = key
233+
234+
@classmethod
235+
def from_cooldown(cls, *, base: type[BaseCooldown], key: KeyT, **kwargs: Any) -> Self:
236+
cd: BaseCooldown = base(**kwargs)
237+
return cls(cd, key=key)
238+
239+
def create_cooldown(self) -> BaseCooldown | None:
240+
return self._cooldown.copy()
241+
242+
def verify_cache(self) -> None:
243+
dead = [k for k, v in self._cache.items() if v.is_dead()]
244+
for key in dead:
245+
del self._cache[key]
246+
247+
async def get_key(self, payload: PT) -> Hashable:
248+
if asyncio.iscoroutinefunction(self._key):
249+
key = await self._key(payload)
250+
else:
251+
key = self._key(payload) # type: ignore
252+
253+
return key
254+
255+
async def get_cooldown(self, payload: PT) -> BaseCooldown | None:
256+
if self._key is BucketType.default:
257+
return self._cooldown
258+
259+
self.verify_cache()
260+
key = await self.get_key(payload)
261+
if key is None:
262+
return
263+
264+
if key not in self._cache:
265+
cooldown = self.create_cooldown()
266+
267+
if cooldown is not None:
268+
self._cache[key] = cooldown
269+
else:
270+
cooldown = self._cache[key]
271+
272+
return cooldown
273+
274+
async def update(self, payload: PT, **kwargs: Any) -> float | None:
275+
bucket = await self.get_cooldown(payload)
276+
277+
if bucket is None:
278+
return None
279+
280+
return bucket.update(**kwargs)

0 commit comments

Comments
 (0)