Skip to content

Commit de0d5b0

Browse files
committedAug 20, 2024
Revert "Revert "Make client accept a function for websocket uri and hadnshakemetadata (#62) (#71)"
This reverts commit 2dc3593.
1 parent d5aabb4 commit de0d5b0

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed
 

‎replit_river/client.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from collections.abc import AsyncIterable, AsyncIterator
2+
from collections.abc import AsyncIterable, AsyncIterator, Awaitable
33
from typing import Any, Callable, Generic, Optional, TypeVar, Union
44

55
from replit_river.client_transport import ClientTransport
@@ -21,20 +21,20 @@
2121
class Client(Generic[HandshakeType]):
2222
def __init__(
2323
self,
24-
websocket_uri: str,
24+
websocket_uri_factory: Callable[[], Awaitable[str]],
2525
client_id: str,
2626
server_id: str,
2727
transport_options: TransportOptions,
28-
handshake_metadata: Optional[HandshakeType] = None,
28+
handshake_metadata_factory: Optional[Callable[[], Awaitable[HandshakeType]]] = None,
2929
) -> None:
3030
self._client_id = client_id
3131
self._server_id = server_id
3232
self._transport = ClientTransport[HandshakeType](
33-
websocket_uri=websocket_uri,
33+
websocket_uri_factory=websocket_uri_factory,
3434
client_id=client_id,
3535
server_id=server_id,
3636
transport_options=transport_options,
37-
handshake_metadata=handshake_metadata,
37+
handshake_metadata_factory=handshake_metadata_factory,
3838
)
3939

4040
async def close(self) -> None:

‎replit_river/client_transport.py

+13-6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import asyncio
22
import logging
3+
from collections.abc import Awaitable, Callable
34
from typing import Generic, Optional, Tuple, TypeVar
45

56
import websockets
@@ -47,24 +48,24 @@
4748
class ClientTransport(Transport, Generic[HandshakeType]):
4849
def __init__(
4950
self,
50-
websocket_uri: str,
51+
websocket_uri_factory: Callable[[], Awaitable[str]],
5152
client_id: str,
5253
server_id: str,
5354
transport_options: TransportOptions,
54-
handshake_metadata: Optional[HandshakeType] = None,
55+
handshake_metadata_factory: Optional[Callable[[], Awaitable[HandshakeType]]] = None,
5556
):
5657
super().__init__(
5758
transport_id=client_id,
5859
transport_options=transport_options,
5960
is_server=False,
6061
)
61-
self._websocket_uri = websocket_uri
62+
self._websocket_uri_factory = websocket_uri_factory
6263
self._client_id = client_id
6364
self._server_id = server_id
6465
self._rate_limiter = LeakyBucketRateLimit(
6566
transport_options.connection_retry_options
6667
)
67-
self._handshake_metadata = handshake_metadata
68+
self._handshake_metadata_factory = handshake_metadata_factory
6869
# We want to make sure there's only one session creation at a time
6970
self._create_session_lock = asyncio.Lock()
7071

@@ -110,12 +111,18 @@ async def _establish_new_connection(
110111
break
111112
rate_limit.consume_budget(client_id)
112113
try:
113-
ws = await websockets.connect(self._websocket_uri)
114+
websocket_uri = await self._websocket_uri_factory()
115+
ws = await websockets.connect(websocket_uri)
114116
session_id = (
115117
self.generate_session_id()
116118
if not old_session
117119
else old_session.session_id
118120
)
121+
122+
handshake_metadata = None
123+
if self._handshake_metadata_factory is not None:
124+
handshake_metadata = await self._handshake_metadata_factory()
125+
119126
try:
120127
(
121128
handshake_request,
@@ -124,7 +131,7 @@ async def _establish_new_connection(
124131
self._transport_id,
125132
self._server_id,
126133
session_id,
127-
self._handshake_metadata,
134+
handshake_metadata,
128135
ws,
129136
old_session,
130137
)

‎tests/conftest.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,14 @@ async def client(
137137
transport_options: TransportOptions,
138138
no_logging_error: NoErrors,
139139
) -> AsyncGenerator[Client, None]:
140+
141+
async def websocket_uri_factory() -> str:
142+
return "ws://localhost:8765"
143+
140144
try:
141145
async with serve(server.serve, "localhost", 8765):
142146
client: Client[NoReturn] = Client(
143-
"ws://localhost:8765",
147+
websocket_uri_factory,
144148
client_id="test_client",
145149
server_id="test_server",
146150
transport_options=transport_options,

0 commit comments

Comments
 (0)