Skip to content

Commit 4171274

Browse files
committed
Add extra checks and safety to Conduit shards
1 parent 8903d3b commit 4171274

File tree

1 file changed

+83
-12
lines changed

1 file changed

+83
-12
lines changed

twitchio/client.py

Lines changed: 83 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -2878,7 +2878,9 @@ async def update_shard_count(self, shard_count: int, /, *, assign_transports: bo
28782878
await self._client._associate_shards(shard_ids)
28792879
elif self.shard_count < len(self.websockets):
28802880
remove = len(self.websockets) - self.shard_count
2881-
await self._disassociate_shards(self.shard_count, remove)
2881+
2882+
async with self._client._associate_lock:
2883+
await self._disassociate_shards(self.shard_count, remove)
28822884

28832885
return self
28842886

@@ -3249,12 +3251,65 @@ def __init__(
32493251

32503252
self._conduit_info: ConduitInfo = ConduitInfo(self)
32513253
self._closing: bool = False
3254+
self._background_check_task: asyncio.Task[None] | None = None
3255+
self._associate_lock: asyncio.Lock = asyncio.Lock()
32523256

32533257
super().__init__(client_id=client_id, client_secret=client_secret, bot_id=bot_id, **kwargs)
32543258

32553259
def __repr__(self) -> str:
    """Return the name of the instance's class as its representation."""
    class_name: str = self.__class__.__name__
    return class_name
32573261

3262+
async def _conduit_check(self) -> None:
    """Periodically verify the shards of the assigned Conduit and try to repair broken ones.

    The coroutine loops forever, so it only stops when the task running it is
    cancelled. On every iteration it:

    1. Sleeps for 120 seconds.
    2. Fetches the Conduits and selects the one whose ID equals
       ``self.conduit_info.id``.
    3. Walks the shards of that Conduit, collecting the IDs of shards that are
       listed in ``self._shard_ids``, are not webhook based (no callback and a
       status that does not start with ``"webhook"``) and whose status is not
       ``"enabled"``.
    4. Passes the collected IDs to ``_associate_shards``.

    Errors raised while fetching Conduits or shards are logged at debug level
    and the iteration is skipped. A failed re-association is logged as a
    warning. Neither terminates the loop.
    """
    while True:
        await asyncio.sleep(120)
        logger.debug("Checking status of Conduit assigned to %r", self)

        try:
            conduits = await self.fetch_conduits()
        except Exception as e:
            # Best effort: a failed fetch is simply retried on the next iteration.
            logger.debug("Exception received fetching Conduits during Conduit checK: %s. Disregarding...", e)
            continue

        conduit_id = self.conduit_info.id
        conduit: Conduit | None = next((c for c in conduits if c.id == conduit_id), None)

        if conduit is None:
            logger.debug("No conduit found during Conduit check. Disregarding...")
            continue

        broken: list[int] = []

        try:
            async for shard in conduit.fetch_shards():
                shard_id = int(shard.id)

                # Only shards tracked by this client are considered.
                if shard_id not in self._shard_ids:
                    continue

                # Webhook shards are skipped; they are not re-associated from here.
                if shard.callback or shard.status.startswith("webhook"):
                    continue

                if shard.status != "enabled":
                    broken.append(shard_id)
        except Exception as e:
            logger.debug("Exception received fetching Conduit Shards during Conduit checK: %s. Disregarding...", e)
            continue

        if not broken:
            continue

        logger.debug("Potentially broken shards found during Conduit check. Trying to re-associate: %r", broken)
        try:
            await self._associate_shards(broken)
        except Exception:
            # The message has a %r placeholder, so the shard IDs must be supplied
            # or the log line is emitted with a literal "%r".
            logger.warning(
                "An attempt to re-associate Conduit Shards: %r was unsuccessful. Consider rebalancing.",
                broken,
            )
3312+
32583313
async def _setup(self) -> None:
32593314
# Subscribe to "conduit.shard.disabled"
32603315

@@ -3301,7 +3356,9 @@ async def _setup(self) -> None:
33013356

33023357
await self._associate_shards(self._shard_ids)
33033358
await self.setup_hook()
3359+
33043360
self._setup_called = True
3361+
self._background_check_task = asyncio.create_task(self._conduit_check())
33053362

33063363
async def _websocket_closed(self, payload: WebsocketClosed) -> None:
33073364
if self._closing:
@@ -3377,22 +3434,29 @@ async def _process_batched(self, batched: list[Websocket]) -> None:
33773434
logger.info("Associated shards with %r successfully.", self._conduit_info)
33783435

33793436
async def _associate_shards(self, shard_ids: list[int]) -> None:
    """Create websockets for the given shard IDs and process them in batches of ten.

    The whole operation runs while holding ``self._associate_lock``, so
    concurrent callers are serialised. After all batches were processed,
    ``self._shard_ids`` is rebuilt as the sorted integer keys of
    ``self._conduit_info._sockets``.

    Parameters
    ----------
    shard_ids: list[int]
        The IDs of the shards to create websockets for.

    Raises
    ------
    AssertionError
        If ``self._conduit_info.conduit`` is falsy (only when assertions are enabled).
    """
    batch_size = 10

    # ``async with`` releases the lock on every exit path, including errors and
    # cancellation, without a manual acquire/try/finally/release sequence.
    async with self._associate_lock:
        assert self._conduit_info.conduit

        for start in range(0, len(shard_ids), batch_size):
            # A fresh list per batch, so nothing handed to _process_batched is mutated afterwards.
            batched: list[Websocket] = [
                Websocket(client=self, http=self._http, shard_id=str(n))
                for n in shard_ids[start : start + batch_size]
            ]
            await self._process_batched(batched)

        self._shard_ids = sorted(int(k) for k in self._conduit_info._sockets)
33963460

33973461
async def _generate_new_conduit(self) -> Conduit:
33983462
if not self._shard_ids:
@@ -3548,6 +3612,13 @@ async def close(self, **options: Any) -> None:
35483612
await socket.close()
35493613

35503614
self._conduit_info._sockets.clear()
3615+
3616+
if self._background_check_task:
3617+
try:
3618+
self._background_check_task.cancel()
3619+
except Exception:
3620+
pass
3621+
35513622
await super().close(**options)
35523623

35533624
async def delete_websocket_subscription(self, *args: Any, **kwargs: Any) -> Any:

0 commit comments

Comments
 (0)