11import asyncio
22import functools
3+ import logging
34import random
45import socket
56import sys
@@ -131,6 +132,14 @@ def __del__(self) -> None:
131132 )
132133
133134
135+ async def _wait_for_close (waiters : List [Awaitable [object ]]) -> None :
136+ """Wait for all waiters to finish closing."""
137+ results = await asyncio .gather (* waiters , return_exceptions = True )
138+ for res in results :
139+ if isinstance (res , Exception ):
140+ logging .error ("Error while closing connector: %r" , res )
141+
142+
134143class Connection :
135144
136145 _source_traceback = None
@@ -222,10 +231,14 @@ def closed(self) -> bool:
222231class _TransportPlaceholder :
223232 """placeholder for BaseConnector.connect function"""
224233
225- __slots__ = ()
234+ __slots__ = ("closed" ,)
235+
236+ def __init__ (self , closed_future : asyncio .Future [Optional [Exception ]]) -> None :
237+ """Initialize a placeholder for a transport."""
238+ self .closed = closed_future
226239
227240 def close (self ) -> None :
228- """Close the placeholder transport ."""
241+ """Close the placeholder."""
229242
230243
231244class BaseConnector :
@@ -322,6 +335,10 @@ def __init__(
322335
323336 self ._cleanup_closed_disabled = not enable_cleanup_closed
324337 self ._cleanup_closed_transports : List [Optional [asyncio .Transport ]] = []
338+ self ._placeholder_future : asyncio .Future [Optional [Exception ]] = (
339+ loop .create_future ()
340+ )
341+ self ._placeholder_future .set_result (None )
325342 self ._cleanup_closed ()
326343
327344 def __del__ (self , _warnings : Any = warnings ) -> None :
@@ -454,18 +471,30 @@ def _cleanup_closed(self) -> None:
454471
455472 def close (self ) -> Awaitable [None ]:
456473 """Close all opened transports."""
457- self ._close ()
458- return _DeprecationWaiter (noop ())
474+ if not (waiters := self ._close ()):
475+ # If there are no connections to close, we can return a noop
476+ # awaitable to avoid scheduling a task on the event loop.
477+ return _DeprecationWaiter (noop ())
478+ coro = _wait_for_close (waiters )
479+ if sys .version_info >= (3 , 12 ):
480+ # Optimization for Python 3.12, try to close connections
481+ # immediately to avoid having to schedule the task on the event loop.
482+ task = asyncio .Task (coro , loop = self ._loop , eager_start = True )
483+ else :
484+ task = self ._loop .create_task (coro )
485+ return _DeprecationWaiter (task )
486+
487+ def _close (self ) -> List [Awaitable [object ]]:
488+ waiters : List [Awaitable [object ]] = []
459489
460- def _close (self ) -> None :
461490 if self ._closed :
462- return
491+ return waiters
463492
464493 self ._closed = True
465494
466495 try :
467496 if self ._loop .is_closed ():
468- return
497+ return waiters
469498
470499 # cancel cleanup task
471500 if self ._cleanup_handle :
@@ -476,16 +505,20 @@ def _close(self) -> None:
476505 self ._cleanup_closed_handle .cancel ()
477506
478507 for data in self ._conns .values ():
479- for proto , t0 in data :
508+ for proto , _ in data :
480509 proto .close ()
510+ waiters .append (proto .closed )
481511
482512 for proto in self ._acquired :
483513 proto .close ()
514+ waiters .append (proto .closed )
484515
485516 for transport in self ._cleanup_closed_transports :
486517 if transport is not None :
487518 transport .abort ()
488519
520+ return waiters
521+
489522 finally :
490523 self ._conns .clear ()
491524 self ._acquired .clear ()
@@ -546,7 +579,9 @@ async def connect(
546579 if (conn := await self ._get (key , traces )) is not None :
547580 return conn
548581
549- placeholder = cast (ResponseHandler , _TransportPlaceholder ())
582+ placeholder = cast (
583+ ResponseHandler , _TransportPlaceholder (self ._placeholder_future )
584+ )
550585 self ._acquired .add (placeholder )
551586 if self ._limit_per_host :
552587 self ._acquired_per_host [key ].add (placeholder )
@@ -898,15 +933,18 @@ def __init__(
898933 self ._resolve_host_tasks : Set ["asyncio.Task[List[ResolveResult]]" ] = set ()
899934 self ._socket_factory = socket_factory
900935
901- def close (self ) -> Awaitable [None ]:
936+ def _close (self ) -> List [ Awaitable [object ] ]:
902937 """Close all ongoing DNS calls."""
903938 for fut in chain .from_iterable (self ._throttle_dns_futures .values ()):
904939 fut .cancel ()
905940
941+ waiters = super ()._close ()
942+
906943 for t in self ._resolve_host_tasks :
907944 t .cancel ()
945+ waiters .append (t )
908946
909- return super (). close ()
947+ return waiters
910948
911949 @property
912950 def family (self ) -> int :
0 commit comments