@@ -2878,7 +2878,9 @@ async def update_shard_count(self, shard_count: int, /, *, assign_transports: bo
28782878 await self ._client ._associate_shards (shard_ids )
28792879 elif self .shard_count < len (self .websockets ):
28802880 remove = len (self .websockets ) - self .shard_count
2881- await self ._disassociate_shards (self .shard_count , remove )
2881+
2882+ async with self ._client ._associate_lock :
2883+ await self ._disassociate_shards (self .shard_count , remove )
28822884
28832885 return self
28842886
@@ -3249,12 +3251,65 @@ def __init__(
32493251
32503252 self ._conduit_info : ConduitInfo = ConduitInfo (self )
32513253 self ._closing : bool = False
3254+ self ._background_check_task : asyncio .Task [None ] | None = None
3255+ self ._associate_lock : asyncio .Lock = asyncio .Lock ()
32523256
32533257 super ().__init__ (client_id = client_id , client_secret = client_secret , bot_id = bot_id , ** kwargs )
32543258
32553259 def __repr__ (self ) -> str :
32563260 return self .__class__ .__name__
32573261
3262+ async def _conduit_check (self ) -> None :
3263+ while True :
3264+ await asyncio .sleep (120 )
3265+ logger .debug ("Checking status of Conduit assigned to %r" , self )
3266+
3267+ try :
3268+ conduits = await self .fetch_conduits ()
3269+ except Exception as e :
3270+ logger .debug ("Exception received fetching Conduits during Conduit checK: %s. Disregarding..." , e )
3271+ continue
3272+
3273+ conduit : Conduit | None = None
3274+
3275+ for c in conduits :
3276+ if c .id == self .conduit_info .id :
3277+ conduit = c
3278+
3279+ if not conduit :
3280+ logger .debug ("No conduit found during Conduit check. Disregarding..." )
3281+ continue
3282+
3283+ broken : list [int ] = []
3284+
3285+ try :
3286+ async for shard in conduit .fetch_shards ():
3287+ shard_id = int (shard .id )
3288+
3289+ if shard_id not in self ._shard_ids :
3290+ continue
3291+
3292+ if shard .callback :
3293+ continue
3294+
3295+ if shard .status .startswith ("webhook" ):
3296+ continue
3297+
3298+ if shard .status != "enabled" :
3299+ broken .append (shard_id )
3300+ except Exception as e :
3301+ logger .debug ("Exception received fetching Conduit Shards during Conduit checK: %s. Disregarding..." , e )
3302+ continue
3303+
3304+ if not broken :
3305+ continue
3306+
3307+ logger .debug ("Potentially broken shards found during Conduit check. Trying to re-associate: %r" , broken )
3308+ try :
3309+ await self ._associate_shards (broken )
3310+ except Exception :
3311+ logger .warning ("An attempt to re-associate Conduit Shards: %r was unsuccessful. Consider rebalancing." )
3312+
32583313 async def _setup (self ) -> None :
32593314 # Subscribe to "conduit.shard.disabled"
32603315
@@ -3301,7 +3356,9 @@ async def _setup(self) -> None:
33013356
33023357 await self ._associate_shards (self ._shard_ids )
33033358 await self .setup_hook ()
3359+
33043360 self ._setup_called = True
3361+ self ._background_check_task = asyncio .create_task (self ._conduit_check ())
33053362
33063363 async def _websocket_closed (self , payload : WebsocketClosed ) -> None :
33073364 if self ._closing :
@@ -3377,22 +3434,29 @@ async def _process_batched(self, batched: list[Websocket]) -> None:
33773434 logger .info ("Associated shards with %r successfully." , self ._conduit_info )
33783435
33793436 async def _associate_shards (self , shard_ids : list [int ]) -> None :
3380- assert self ._conduit_info . conduit
3437+ await self ._associate_lock . acquire ()
33813438
3382- batched : list [Websocket ] = []
3439+ try :
3440+ assert self ._conduit_info .conduit
33833441
3384- for i , n in enumerate (shard_ids ):
3385- if i % 10 == 0 and i != 0 :
3386- await self ._process_batched (batched )
3387- batched .clear ()
3442+ batched : list [Websocket ] = []
3443+
3444+ for i , n in enumerate (shard_ids ):
3445+ if i % 10 == 0 and i != 0 :
3446+ await self ._process_batched (batched )
3447+ batched .clear ()
33883448
3389- websocket = Websocket (client = self , http = self ._http , shard_id = str (n ))
3390- batched .append (websocket )
3449+ websocket = Websocket (client = self , http = self ._http , shard_id = str (n ))
3450+ batched .append (websocket )
33913451
3392- if batched :
3393- await self ._process_batched (batched )
3452+ if batched :
3453+ await self ._process_batched (batched )
33943454
3395- self ._shard_ids = sorted ([int (k ) for k in self ._conduit_info ._sockets ])
3455+ self ._shard_ids = sorted ([int (k ) for k in self ._conduit_info ._sockets ])
3456+ except :
3457+ raise
3458+ finally :
3459+ self ._associate_lock .release ()
33963460
33973461 async def _generate_new_conduit (self ) -> Conduit :
33983462 if not self ._shard_ids :
@@ -3548,6 +3612,13 @@ async def close(self, **options: Any) -> None:
35483612 await socket .close ()
35493613
35503614 self ._conduit_info ._sockets .clear ()
3615+
3616+ if self ._background_check_task :
3617+ try :
3618+ self ._background_check_task .cancel ()
3619+ except Exception :
3620+ pass
3621+
35513622 await super ().close (** options )
35523623
35533624 async def delete_websocket_subscription (self , * args : Any , ** kwargs : Any ) -> Any :
0 commit comments