@@ -100,7 +100,7 @@ def reset(self) -> None:
100100 raise NotImplementedError
101101
102102 @abc .abstractmethod
103- def update (self ) -> float | None :
103+ def update (self , * args : Any , ** kwargs : Any ) -> float | None :
104104 """Base method which should be implemented to update the cooldown/ratelimit.
105105
106106 This is where your algorithm logic should be contained.
@@ -125,7 +125,7 @@ def copy(self) -> Self:
125125 raise NotImplementedError
126126
127127 @abc .abstractmethod
128- def is_ratelimited (self ) -> bool :
128+ def is_ratelimited (self , * args : Any , ** kwargs : Any ) -> bool :
129129 """Base method which should be implemented which returns a bool indicating whether the cooldown is ratelimited.
130130
131131 Returns
@@ -136,7 +136,7 @@ def is_ratelimited(self) -> bool:
136136 raise NotImplementedError
137137
138138 @abc .abstractmethod
139- def is_dead (self ) -> bool :
139+ def is_dead (self , * args : Any , ** kwargs : Any ) -> bool :
140140 """Base method which should be implemented to indicate whether the cooldown should be considered stale and allowed
141141 to be removed from the ``bucket: cooldown`` mapping.
142142
@@ -154,12 +154,12 @@ class Cooldown(BaseCooldown):
154154 See: :func:`~.commands.cooldown` for more documentation.
155155 """
156156
157- def __init__ (self , * , rate : int , per : int | datetime .timedelta ) -> None :
157+ def __init__ (self , * , rate : int , per : float | datetime .timedelta ) -> None :
158158 if rate <= 0 :
159159 raise ValueError (f'Cooldown rate must be equal to or greater than 1. Got "{ rate } " expected >= 1.' )
160160
161161 self ._rate : int = rate
162- self ._per : datetime .timedelta = datetime .timedelta (seconds = per ) if isinstance (per , int ) else per
162+ self ._per : datetime .timedelta = datetime .timedelta (seconds = per ) if not isinstance (per , datetime . timedelta ) else per
163163
164164 now : datetime .datetime = datetime .datetime .now (tz = datetime .UTC )
165165 self ._window : datetime .datetime = now + self ._per
@@ -219,7 +219,71 @@ def is_dead(self) -> bool:
219219
220220
221221class GCRACooldown (BaseCooldown ):
222- """Not implemented yet."""
222+ """GCRA cooldown algorithm for :func:`~.commands.cooldown`, which implements the ``GCRA`` ratelimiting algorithm.
223+
224+ See: :func:`~.commands.cooldown` for more documentation.
225+ """
226+
227+ def __init__ (self , * , rate : int , per : float | datetime .timedelta ) -> None :
228+ if rate <= 0 :
229+ raise ValueError (f'Cooldown rate must be equal to or greater than 1. Got "{ rate } " expected >= 1.' )
230+
231+ self ._rate : int = rate
232+ self ._per : datetime .timedelta = datetime .timedelta (seconds = per ) if not isinstance (per , datetime .timedelta ) else per
233+
234+ now : datetime .datetime = datetime .datetime .now (tz = datetime .UTC )
235+ self ._tat : datetime .datetime | None = None
236+
237+ if (now + self ._per ) <= now :
238+ raise ValueError ("The provided per value for Cooldowns can not go into the past." )
239+
240+ self .last_updated : datetime .datetime | None = None
241+
242+ @property
243+ def inverse (self ) -> float :
244+ return self ._per .total_seconds () / self ._rate
245+
246+ @property
247+ def per (self ) -> datetime .timedelta :
248+ return self ._per
249+
250+ def reset (self ) -> None :
251+ self .last_updated = None
252+ self ._tat = None
253+
254+ def is_ratelimited (self , * , now : datetime .datetime | None = None ) -> bool :
255+ now = datetime .datetime .now (tz = datetime .UTC ) or now
256+ tat : datetime .datetime = max (self ._tat or now , now )
257+
258+ separation : float = (tat - now ).total_seconds ()
259+ max_interval : float = self ._per .total_seconds () - self .inverse
260+
261+ return separation > max_interval
262+
263+ def update (self ) -> float | None :
264+ now : datetime .datetime = datetime .datetime .now (tz = datetime .UTC )
265+ tat : datetime .datetime = max (self ._tat or now , now )
266+
267+ self .last_updated = now
268+
269+ separation : float = (tat - now ).total_seconds ()
270+ max_interval : float = self ._per .total_seconds () - self .inverse
271+
272+ if separation > max_interval :
273+ return separation - max_interval
274+
275+ new = max (tat , now ) + datetime .timedelta (seconds = self .inverse )
276+ self ._tat = new
277+
278+ def copy (self ) -> Self :
279+ return self .__class__ (rate = self ._rate , per = self ._per )
280+
281+ def is_dead (self ) -> bool :
282+ if self .last_updated is None :
283+ return False
284+
285+ now = datetime .datetime .now (tz = datetime .UTC )
286+ return now > (self .last_updated + self .per )
223287
224288
225289KeyT : TypeAlias = Callable [[Any ], Hashable ] | Callable [[Any ], Coroutine [Any , Any , Hashable ]] | BucketType
0 commit comments