@@ -3209,7 +3209,10 @@ async def main() -> None:
32093209 force_subscribe: bool
32103210 An optional :class:`bool` which when ``True`` will force attempt to subscribe to the subscriptions provided in the
32113211 ``subscriptions`` parameter, regardless of whether a new conduit was created or not. Defaults to ``False``.
3212-
3212+ force_scale: bool
3213+ An optional :class:`bool` which when ``True`` will force the :class:`~twitchio.Conduit` associated with the
3214+ AutoClient/Bot to scale up/down to the provided amount of shards in the ``shard_ids`` parameter if provided. If the
3215+ ``shard_ids`` parameter is not passed, this parameter has no effect. Defaults to ``False``.
32133216 """
32143217
32153218 # NOTE:
@@ -3228,8 +3231,10 @@ def __init__(
32283231 ** kwargs : Unpack [AutoClientOptions ],
32293232 ) -> None :
32303233 self ._shard_ids : list [int ] = kwargs .pop ("shard_ids" , [])
3234+ self ._original_shards = self ._shard_ids
32313235 self ._conduit_id : str | bool | None = kwargs .pop ("conduit_id" , MISSING )
32323236 self ._force_sub : bool = kwargs .pop ("force_subscribe" , False )
3237+ self ._force_scale : bool = kwargs .pop ("force_scale" , False )
32333238 self ._subbed : bool = False
32343239
32353240 if self ._conduit_id is MISSING or self ._conduit_id is None :
@@ -3358,6 +3363,10 @@ async def _setup(self) -> None:
33583363 # TODO: Maybe log currernt conduit info?
33593364 raise MissingConduit ("No conduit could be found with the provided ID or a new one can not be created." )
33603365
3366+ if self ._force_scale and self ._original_shards :
3367+ logger .info ("Scaling %r to %d shards." , len (self ._original_shards ))
3368+ await self ._conduit_info .update_shard_count (len (self ._original_shards ), assign_transports = False )
3369+
33613370 await self ._associate_shards (self ._shard_ids )
33623371 if self ._force_sub and not self ._subbed :
33633372 await self .multi_subscribe (self ._initial_subs )
@@ -3438,7 +3447,12 @@ async def _process_batched(self, batched: list[Websocket]) -> None:
34383447 await self ._conduit_info ._update_shards (payloads )
34393448 self ._conduit_info ._sockets .update ({str (socket ._shard_id ): socket for socket in batched })
34403449
3441- logger .info ("Associated shards with %r successfully." , self ._conduit_info )
3450+ logger .info (
3451+ "Associated shards with %r successfully. Shards: %d / %d (connected / Conduit total)." ,
3452+ self ._conduit_info ,
3453+ len (self ._conduit_info .websockets ),
3454+ self ._conduit_info .shard_count ,
3455+ )
34423456
34433457 async def _associate_shards (self , shard_ids : list [int ]) -> None :
34443458 await self ._associate_lock .acquire ()
@@ -3626,14 +3640,25 @@ async def multi_subscribe(
36263640 )
36273641 return task
36283642
3643+ async def _close_sockets (self ) -> None :
3644+ socks = self ._conduit_info ._sockets .values ()
3645+ logger .info ("Attempting to close %d associated Conduit Websockets." , len (socks ))
3646+
3647+ tasks : list [asyncio .Task [None ]] = [asyncio .create_task (s .close ()) for s in socks ]
3648+ await asyncio .wait (tasks )
3649+
3650+ logger .info ("Successfully closed %d Conduit Websockets on %r." , len (socks ), self )
3651+
36293652 async def close (self , ** options : Any ) -> None :
36303653 if self ._closing :
36313654 return
36323655
36333656 self ._closing = True
36343657
3635- for socket in self ._conduit_info .websockets .values ():
3636- await socket .close ()
3658+ try :
3659+ await self ._close_sockets ()
3660+ except Exception as e :
3661+ logger .warning ("An error occurred during the cleanup of Conduit Websockets: %s" , e )
36373662
36383663 self ._conduit_info ._sockets .clear ()
36393664
0 commit comments