Skip to content

Commit 8a3b310

Browse files
committed
Implement GCRA Cooldown algorithm
1 parent 3a05630 commit 8a3b310

File tree

1 file changed

+70
-6
lines changed

1 file changed

+70
-6
lines changed

twitchio/ext/commands/cooldowns.py

Lines changed: 70 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def reset(self) -> None:
100100
raise NotImplementedError
101101

102102
@abc.abstractmethod
103-
def update(self) -> float | None:
103+
def update(self, *args: Any, **kwargs: Any) -> float | None:
104104
"""Base method which should be implemented to update the cooldown/ratelimit.
105105
106106
This is where your algorithm logic should be contained.
@@ -125,7 +125,7 @@ def copy(self) -> Self:
125125
raise NotImplementedError
126126

127127
@abc.abstractmethod
128-
def is_ratelimited(self) -> bool:
128+
def is_ratelimited(self, *args: Any, **kwargs: Any) -> bool:
129129
"""Base method which should be implemented which returns a bool indicating whether the cooldown is ratelimited.
130130
131131
Returns
@@ -136,7 +136,7 @@ def is_ratelimited(self) -> bool:
136136
raise NotImplementedError
137137

138138
@abc.abstractmethod
139-
def is_dead(self) -> bool:
139+
def is_dead(self, *args: Any, **kwargs: Any) -> bool:
140140
"""Base method which should be implemented to indicate whether the cooldown should be considered stale and allowed
141141
to be removed from the ``bucket: cooldown`` mapping.
142142
@@ -154,12 +154,12 @@ class Cooldown(BaseCooldown):
154154
See: :func:`~.commands.cooldown` for more documentation.
155155
"""
156156

157-
def __init__(self, *, rate: int, per: int | datetime.timedelta) -> None:
157+
def __init__(self, *, rate: int, per: float | datetime.timedelta) -> None:
158158
if rate <= 0:
159159
raise ValueError(f'Cooldown rate must be equal to or greater than 1. Got "{rate}" expected >= 1.')
160160

161161
self._rate: int = rate
162-
self._per: datetime.timedelta = datetime.timedelta(seconds=per) if isinstance(per, int) else per
162+
self._per: datetime.timedelta = datetime.timedelta(seconds=per) if not isinstance(per, datetime.timedelta) else per
163163

164164
now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
165165
self._window: datetime.datetime = now + self._per
@@ -219,7 +219,71 @@ def is_dead(self) -> bool:
219219

220220

221221
class GCRACooldown(BaseCooldown):
222-
"""Not implemented yet."""
222+
"""GCRA cooldown algorithm for :func:`~.commands.cooldown`, which implements the ``GCRA`` ratelimiting algorithm.
223+
224+
See: :func:`~.commands.cooldown` for more documentation.
225+
"""
226+
227+
def __init__(self, *, rate: int, per: float | datetime.timedelta) -> None:
228+
if rate <= 0:
229+
raise ValueError(f'Cooldown rate must be equal to or greater than 1. Got "{rate}" expected >= 1.')
230+
231+
self._rate: int = rate
232+
self._per: datetime.timedelta = datetime.timedelta(seconds=per) if not isinstance(per, datetime.timedelta) else per
233+
234+
now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
235+
self._tat: datetime.datetime | None = None
236+
237+
if (now + self._per) <= now:
238+
raise ValueError("The provided per value for Cooldowns can not go into the past.")
239+
240+
self.last_updated: datetime.datetime | None = None
241+
242+
@property
243+
def inverse(self) -> float:
244+
return self._per.total_seconds() / self._rate
245+
246+
@property
247+
def per(self) -> datetime.timedelta:
248+
return self._per
249+
250+
def reset(self) -> None:
251+
self.last_updated = None
252+
self._tat = None
253+
254+
def is_ratelimited(self, *, now: datetime.datetime | None = None) -> bool:
255+
now = datetime.datetime.now(tz=datetime.UTC) or now
256+
tat: datetime.datetime = max(self._tat or now, now)
257+
258+
separation: float = (tat - now).total_seconds()
259+
max_interval: float = self._per.total_seconds() - self.inverse
260+
261+
return separation > max_interval
262+
263+
def update(self) -> float | None:
264+
now: datetime.datetime = datetime.datetime.now(tz=datetime.UTC)
265+
tat: datetime.datetime = max(self._tat or now, now)
266+
267+
self.last_updated = now
268+
269+
separation: float = (tat - now).total_seconds()
270+
max_interval: float = self._per.total_seconds() - self.inverse
271+
272+
if separation > max_interval:
273+
return separation - max_interval
274+
275+
new = max(tat, now) + datetime.timedelta(seconds=self.inverse)
276+
self._tat = new
277+
278+
def copy(self) -> Self:
279+
return self.__class__(rate=self._rate, per=self._per)
280+
281+
def is_dead(self) -> bool:
282+
if self.last_updated is None:
283+
return False
284+
285+
now = datetime.datetime.now(tz=datetime.UTC)
286+
return now > (self.last_updated + self.per)
223287

224288

225289
KeyT: TypeAlias = Callable[[Any], Hashable] | Callable[[Any], Coroutine[Any, Any, Hashable]] | BucketType

0 commit comments

Comments
 (0)