1
1
import asyncio
2
2
import logging
3
+ from collections .abc import Awaitable , Callable
3
4
from typing import Generic , Optional , Tuple , TypeVar
4
5
5
6
import websockets
47
48
class ClientTransport (Transport , Generic [HandshakeType ]):
48
49
def __init__ (
49
50
self ,
50
- websocket_uri : str ,
51
+ websocket_uri_factory : Callable [[], Awaitable [ str ]] ,
51
52
client_id : str ,
52
53
server_id : str ,
53
54
transport_options : TransportOptions ,
54
- handshake_metadata : Optional [HandshakeType ] = None ,
55
+ handshake_metadata_factory : Optional [Callable [[], Awaitable [ HandshakeType ]] ] = None ,
55
56
):
56
57
super ().__init__ (
57
58
transport_id = client_id ,
58
59
transport_options = transport_options ,
59
60
is_server = False ,
60
61
)
61
- self ._websocket_uri = websocket_uri
62
+ self ._websocket_uri_factory = websocket_uri_factory
62
63
self ._client_id = client_id
63
64
self ._server_id = server_id
64
65
self ._rate_limiter = LeakyBucketRateLimit (
65
66
transport_options .connection_retry_options
66
67
)
67
- self ._handshake_metadata = handshake_metadata
68
+ self ._handshake_metadata_factory = handshake_metadata_factory
68
69
# We want to make sure there's only one session creation at a time
69
70
self ._create_session_lock = asyncio .Lock ()
70
71
@@ -110,12 +111,18 @@ async def _establish_new_connection(
110
111
break
111
112
rate_limit .consume_budget (client_id )
112
113
try :
113
- ws = await websockets .connect (self ._websocket_uri )
114
+ websocket_uri = await self ._websocket_uri_factory ()
115
+ ws = await websockets .connect (websocket_uri )
114
116
session_id = (
115
117
self .generate_session_id ()
116
118
if not old_session
117
119
else old_session .session_id
118
120
)
121
+
122
+ handshake_metadata = None
123
+ if self ._handshake_metadata_factory is not None :
124
+ handshake_metadata = await self ._handshake_metadata_factory ()
125
+
119
126
try :
120
127
(
121
128
handshake_request ,
@@ -124,7 +131,7 @@ async def _establish_new_connection(
124
131
self ._transport_id ,
125
132
self ._server_id ,
126
133
session_id ,
127
- self . _handshake_metadata ,
134
+ handshake_metadata ,
128
135
ws ,
129
136
old_session ,
130
137
)
0 commit comments