Skip to content

Commit 0571244

Browse files
LumabotsDorukyum
andauthored
fix: correct generic return type in component utils (#2796)
* Update CHANGELOG.md Signed-off-by: Lumouille <[email protected]> * sync with pycord * fix: update Item type hints to Item[View] in view.py * fix: update Item type hints to use TypeVar[V] in view.py --------- Signed-off-by: Lumouille <[email protected]> Co-authored-by: Dorukyum <[email protected]>
1 parent aeacfed commit 0571244

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

discord/ui/view.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
import traceback
3333
from functools import partial
3434
from itertools import groupby
35-
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, Sequence
35+
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Iterator, Sequence, TypeVar
3636

3737
from ..components import ActionRow as ActionRowComponent
3838
from ..components import Button as ButtonComponent
@@ -51,6 +51,8 @@
5151
from ..state import ConnectionState
5252
from ..types.components import Component as ComponentPayload
5353

54+
V = TypeVar("V", bound="View", covariant=True)
55+
5456

5557
def _walk_all_components(components: list[Component]) -> Iterator[Component]:
5658
for item in components:
@@ -60,7 +62,7 @@ def _walk_all_components(components: list[Component]) -> Iterator[Component]:
6062
yield item
6163

6264

63-
def _component_to_item(component: Component) -> Item:
65+
def _component_to_item(component: Component) -> Item[V]:
6466
if isinstance(component, ButtonComponent):
6567
from .button import Button
6668

@@ -75,7 +77,7 @@ def _component_to_item(component: Component) -> Item:
7577
class _ViewWeights:
7678
__slots__ = ("weights",)
7779

78-
def __init__(self, children: list[Item]):
80+
def __init__(self, children: list[Item[V]]):
7981
self.weights: list[int] = [0, 0, 0, 0, 0]
8082

8183
key = lambda i: sys.maxsize if i.row is None else i.row
@@ -84,14 +86,14 @@ def __init__(self, children: list[Item]):
8486
for item in group:
8587
self.add_item(item)
8688

87-
def find_open_space(self, item: Item) -> int:
89+
def find_open_space(self, item: Item[V]) -> int:
8890
for index, weight in enumerate(self.weights):
8991
if weight + item.width <= 5:
9092
return index
9193

9294
raise ValueError("could not find open space for item")
9395

94-
def add_item(self, item: Item) -> None:
96+
def add_item(self, item: Item[V]) -> None:
9597
if item.row is not None:
9698
total = self.weights[item.row] + item.width
9799
if total > 5:
@@ -105,7 +107,7 @@ def add_item(self, item: Item) -> None:
105107
self.weights[index] += item.width
106108
item._rendered_row = index
107109

108-
def remove_item(self, item: Item) -> None:
110+
def remove_item(self, item: Item[V]) -> None:
109111
if item._rendered_row is not None:
110112
self.weights[item._rendered_row] -= item.width
111113
item._rendered_row = None
@@ -163,15 +165,15 @@ def __init_subclass__(cls) -> None:
163165

164166
def __init__(
165167
self,
166-
*items: Item,
168+
*items: Item[V],
167169
timeout: float | None = 180.0,
168170
disable_on_timeout: bool = False,
169171
):
170172
self.timeout = timeout
171173
self.disable_on_timeout = disable_on_timeout
172-
self.children: list[Item] = []
174+
self.children: list[Item[V]] = []
173175
for func in self.__view_children_items__:
174-
item: Item = func.__discord_ui_model_type__(
176+
item: Item[V] = func.__discord_ui_model_type__(
175177
**func.__discord_ui_model_kwargs__
176178
)
177179
item.callback = partial(func, self, item)
@@ -213,7 +215,7 @@ async def __timeout_task_impl(self) -> None:
213215
await asyncio.sleep(self.__timeout_expiry - now)
214216

215217
def to_components(self) -> list[dict[str, Any]]:
216-
def key(item: Item) -> int:
218+
def key(item: Item[V]) -> int:
217219
return item._rendered_row or 0
218220

219221
children = sorted(self.children, key=key)
@@ -267,7 +269,7 @@ def _expires_at(self) -> float | None:
267269
return time.monotonic() + self.timeout
268270
return None
269271

270-
def add_item(self, item: Item) -> None:
272+
def add_item(self, item: Item[V]) -> None:
271273
"""Adds an item to the view.
272274
273275
Parameters
@@ -295,7 +297,7 @@ def add_item(self, item: Item) -> None:
295297
item._view = self
296298
self.children.append(item)
297299

298-
def remove_item(self, item: Item) -> None:
300+
def remove_item(self, item: Item[V]) -> None:
299301
"""Removes an item from the view.
300302
301303
Parameters
@@ -316,7 +318,7 @@ def clear_items(self) -> None:
316318
self.children.clear()
317319
self.__weights.clear()
318320

319-
def get_item(self, custom_id: str) -> Item | None:
321+
def get_item(self, custom_id: str) -> Item[V] | None:
320322
"""Get an item from the view with the given custom ID. Alias for `utils.get(view.children, custom_id=custom_id)`.
321323
322324
Parameters
@@ -391,7 +393,7 @@ async def on_check_failure(self, interaction: Interaction) -> None:
391393
"""
392394

393395
async def on_error(
394-
self, error: Exception, item: Item, interaction: Interaction
396+
self, error: Exception, item: Item[V], interaction: Interaction
395397
) -> None:
396398
"""|coro|
397399
@@ -414,7 +416,7 @@ async def on_error(
414416
error.__class__, error, error.__traceback__, file=sys.stderr
415417
)
416418

417-
async def _scheduled_task(self, item: Item, interaction: Interaction):
419+
async def _scheduled_task(self, item: Item[V], interaction: Interaction):
418420
try:
419421
if self.timeout:
420422
self.__timeout_expiry = time.monotonic() + self.timeout
@@ -446,7 +448,7 @@ def _dispatch_timeout(self):
446448
self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}"
447449
)
448450

449-
def _dispatch_item(self, item: Item, interaction: Interaction):
451+
def _dispatch_item(self, item: Item[V], interaction: Interaction):
450452
if self.__stopped.done():
451453
return
452454

@@ -460,10 +462,10 @@ def _dispatch_item(self, item: Item, interaction: Interaction):
460462

461463
def refresh(self, components: list[Component]):
462464
# This is pretty hacky at the moment
463-
old_state: dict[tuple[int, str], Item] = {
465+
old_state: dict[tuple[int, str], Item[V]] = {
464466
(item.type.value, item.custom_id): item for item in self.children if item.is_dispatchable() # type: ignore
465467
}
466-
children: list[Item] = [
468+
children: list[Item[V]] = [
467469
item for item in self.children if not item.is_dispatchable()
468470
]
469471
for component in _walk_all_components(components):
@@ -529,7 +531,7 @@ async def wait(self) -> bool:
529531
"""
530532
return await self.__stopped
531533

532-
def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
534+
def disable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None:
533535
"""
534536
Disables all items in the view.
535537
@@ -542,7 +544,7 @@ def disable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
542544
if exclusions is None or child not in exclusions:
543545
child.disabled = True
544546

545-
def enable_all_items(self, *, exclusions: list[Item] | None = None) -> None:
547+
def enable_all_items(self, *, exclusions: list[Item[V]] | None = None) -> None:
546548
"""
547549
Enables all items in the view.
548550
@@ -567,7 +569,7 @@ def message(self, value):
567569
class ViewStore:
568570
def __init__(self, state: ConnectionState):
569571
# (component_type, message_id, custom_id): (View, Item)
570-
self._views: dict[tuple[int, int | None, str], tuple[View, Item]] = {}
572+
self._views: dict[tuple[int, int | None, str], tuple[View, Item[V]]] = {}
571573
# message_id: View
572574
self._synced_message_views: dict[int, View] = {}
573575
self._state: ConnectionState = state

0 commit comments

Comments
 (0)