32
32
import traceback
33
33
from functools import partial
34
34
from itertools import groupby
35
- from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence
35
+ from typing import TYPE_CHECKING , Any , Callable , ClassVar , Iterator , Sequence , TypeVar
36
36
37
37
from ..components import ActionRow as ActionRowComponent
38
38
from ..components import Button as ButtonComponent
51
51
from ..state import ConnectionState
52
52
from ..types .components import Component as ComponentPayload
53
53
54
+ V = TypeVar ("V" , bound = "View" , covariant = True )
55
+
54
56
55
57
def _walk_all_components (components : list [Component ]) -> Iterator [Component ]:
56
58
for item in components :
@@ -60,7 +62,7 @@ def _walk_all_components(components: list[Component]) -> Iterator[Component]:
60
62
yield item
61
63
62
64
63
- def _component_to_item (component : Component ) -> Item :
65
+ def _component_to_item (component : Component ) -> Item [ V ] :
64
66
if isinstance (component , ButtonComponent ):
65
67
from .button import Button
66
68
@@ -75,7 +77,7 @@ def _component_to_item(component: Component) -> Item:
75
77
class _ViewWeights :
76
78
__slots__ = ("weights" ,)
77
79
78
- def __init__ (self , children : list [Item ]):
80
+ def __init__ (self , children : list [Item [ V ] ]):
79
81
self .weights : list [int ] = [0 , 0 , 0 , 0 , 0 ]
80
82
81
83
key = lambda i : sys .maxsize if i .row is None else i .row
@@ -84,14 +86,14 @@ def __init__(self, children: list[Item]):
84
86
for item in group :
85
87
self .add_item (item )
86
88
87
- def find_open_space (self , item : Item ) -> int :
89
+ def find_open_space (self , item : Item [ V ] ) -> int :
88
90
for index , weight in enumerate (self .weights ):
89
91
if weight + item .width <= 5 :
90
92
return index
91
93
92
94
raise ValueError ("could not find open space for item" )
93
95
94
- def add_item (self , item : Item ) -> None :
96
+ def add_item (self , item : Item [ V ] ) -> None :
95
97
if item .row is not None :
96
98
total = self .weights [item .row ] + item .width
97
99
if total > 5 :
@@ -105,7 +107,7 @@ def add_item(self, item: Item) -> None:
105
107
self .weights [index ] += item .width
106
108
item ._rendered_row = index
107
109
108
- def remove_item (self , item : Item ) -> None :
110
+ def remove_item (self , item : Item [ V ] ) -> None :
109
111
if item ._rendered_row is not None :
110
112
self .weights [item ._rendered_row ] -= item .width
111
113
item ._rendered_row = None
@@ -163,15 +165,15 @@ def __init_subclass__(cls) -> None:
163
165
164
166
def __init__ (
165
167
self ,
166
- * items : Item ,
168
+ * items : Item [ V ] ,
167
169
timeout : float | None = 180.0 ,
168
170
disable_on_timeout : bool = False ,
169
171
):
170
172
self .timeout = timeout
171
173
self .disable_on_timeout = disable_on_timeout
172
- self .children : list [Item ] = []
174
+ self .children : list [Item [ V ] ] = []
173
175
for func in self .__view_children_items__ :
174
- item : Item = func .__discord_ui_model_type__ (
176
+ item : Item [ V ] = func .__discord_ui_model_type__ (
175
177
** func .__discord_ui_model_kwargs__
176
178
)
177
179
item .callback = partial (func , self , item )
@@ -213,7 +215,7 @@ async def __timeout_task_impl(self) -> None:
213
215
await asyncio .sleep (self .__timeout_expiry - now )
214
216
215
217
def to_components (self ) -> list [dict [str , Any ]]:
216
- def key (item : Item ) -> int :
218
+ def key (item : Item [ V ] ) -> int :
217
219
return item ._rendered_row or 0
218
220
219
221
children = sorted (self .children , key = key )
@@ -267,7 +269,7 @@ def _expires_at(self) -> float | None:
267
269
return time .monotonic () + self .timeout
268
270
return None
269
271
270
- def add_item (self , item : Item ) -> None :
272
+ def add_item (self , item : Item [ V ] ) -> None :
271
273
"""Adds an item to the view.
272
274
273
275
Parameters
@@ -295,7 +297,7 @@ def add_item(self, item: Item) -> None:
295
297
item ._view = self
296
298
self .children .append (item )
297
299
298
- def remove_item (self , item : Item ) -> None :
300
+ def remove_item (self , item : Item [ V ] ) -> None :
299
301
"""Removes an item from the view.
300
302
301
303
Parameters
@@ -316,7 +318,7 @@ def clear_items(self) -> None:
316
318
self .children .clear ()
317
319
self .__weights .clear ()
318
320
319
- def get_item (self , custom_id : str ) -> Item | None :
321
+ def get_item (self , custom_id : str ) -> Item [ V ] | None :
320
322
"""Get an item from the view with the given custom ID. Alias for `utils.get(view.children, custom_id=custom_id)`.
321
323
322
324
Parameters
@@ -391,7 +393,7 @@ async def on_check_failure(self, interaction: Interaction) -> None:
391
393
"""
392
394
393
395
async def on_error (
394
- self , error : Exception , item : Item , interaction : Interaction
396
+ self , error : Exception , item : Item [ V ] , interaction : Interaction
395
397
) -> None :
396
398
"""|coro|
397
399
@@ -414,7 +416,7 @@ async def on_error(
414
416
error .__class__ , error , error .__traceback__ , file = sys .stderr
415
417
)
416
418
417
- async def _scheduled_task (self , item : Item , interaction : Interaction ):
419
+ async def _scheduled_task (self , item : Item [ V ] , interaction : Interaction ):
418
420
try :
419
421
if self .timeout :
420
422
self .__timeout_expiry = time .monotonic () + self .timeout
@@ -446,7 +448,7 @@ def _dispatch_timeout(self):
446
448
self .on_timeout (), name = f"discord-ui-view-timeout-{ self .id } "
447
449
)
448
450
449
- def _dispatch_item (self , item : Item , interaction : Interaction ):
451
+ def _dispatch_item (self , item : Item [ V ] , interaction : Interaction ):
450
452
if self .__stopped .done ():
451
453
return
452
454
@@ -460,10 +462,10 @@ def _dispatch_item(self, item: Item, interaction: Interaction):
460
462
461
463
def refresh (self , components : list [Component ]):
462
464
# This is pretty hacky at the moment
463
- old_state : dict [tuple [int , str ], Item ] = {
465
+ old_state : dict [tuple [int , str ], Item [ V ] ] = {
464
466
(item .type .value , item .custom_id ): item for item in self .children if item .is_dispatchable () # type: ignore
465
467
}
466
- children : list [Item ] = [
468
+ children : list [Item [ V ] ] = [
467
469
item for item in self .children if not item .is_dispatchable ()
468
470
]
469
471
for component in _walk_all_components (components ):
@@ -529,7 +531,7 @@ async def wait(self) -> bool:
529
531
"""
530
532
return await self .__stopped
531
533
532
- def disable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
534
+ def disable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
533
535
"""
534
536
Disables all items in the view.
535
537
@@ -542,7 +544,7 @@ def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
542
544
if exclusions is None or child not in exclusions :
543
545
child .disabled = True
544
546
545
- def enable_all_items (self , * , exclusions : list [Item ] | None = None ) -> None :
547
+ def enable_all_items (self , * , exclusions : list [Item [ V ] ] | None = None ) -> None :
546
548
"""
547
549
Enables all items in the view.
548
550
@@ -567,7 +569,7 @@ def message(self, value):
567
569
class ViewStore :
568
570
def __init__ (self , state : ConnectionState ):
569
571
# (component_type, message_id, custom_id): (View, Item)
570
- self ._views : dict [tuple [int , int | None , str ], tuple [View , Item ]] = {}
572
+ self ._views : dict [tuple [int , int | None , str ], tuple [View , Item [ V ] ]] = {}
571
573
# message_id: View
572
574
self ._synced_message_views : dict [int , View ] = {}
573
575
self ._state : ConnectionState = state
0 commit comments