Skip to content

Commit 2dc3593

Browse files
Revert "Make client accept a function for websocket uri and hadnshakemetadata (#62) (#71)
Why === This reverts commit 71ba3df. We've been seeing some issues attempting to negotiate tokens successfully, this seems like it is most likely to be the cause. What changed ============ Reverting async metadata for now Test plan ========= Does @airportyh's manual testing pass again?
1 parent cdef2f3 commit 2dc3593

File tree

3 files changed

+13
-26
lines changed

3 files changed

+13
-26
lines changed

replit_river/client.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
2-
from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Callable
3-
from typing import Any, Optional, Union
2+
from collections.abc import AsyncIterable, AsyncIterator
3+
from typing import Any, Callable, Optional, Union
44

55
from replit_river.client_transport import ClientTransport
66
from replit_river.transport_options import TransportOptions
@@ -16,23 +16,22 @@
1616

1717

1818
class Client:
19-
2019
def __init__(
2120
self,
22-
websocket_uri_factory: Callable[[], Awaitable[str]],
21+
websocket_uri: str,
2322
client_id: str,
2423
server_id: str,
2524
transport_options: TransportOptions,
26-
handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None,
25+
handshake_metadata: Optional[Any] = None,
2726
) -> None:
2827
self._client_id = client_id
2928
self._server_id = server_id
3029
self._transport = ClientTransport(
31-
websocket_uri_factory=websocket_uri_factory,
30+
websocket_uri=websocket_uri,
3231
client_id=client_id,
3332
server_id=server_id,
3433
transport_options=transport_options,
35-
handshake_metadata_factory=handshake_metadata_factory,
34+
handshake_metadata=handshake_metadata,
3635
)
3736

3837
async def close(self) -> None:

replit_river/client_transport.py

Lines changed: 6 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import asyncio
22
import logging
3-
from collections.abc import Awaitable, Callable
43
from typing import Any, Optional, Tuple
54

65
import websockets
@@ -43,27 +42,26 @@
4342

4443

4544
class ClientTransport(Transport):
46-
4745
def __init__(
4846
self,
49-
websocket_uri_factory: Callable[[], Awaitable[str]],
47+
websocket_uri: str,
5048
client_id: str,
5149
server_id: str,
5250
transport_options: TransportOptions,
53-
handshake_metadata_factory: Optional[Callable[[], Awaitable[Any]]] = None,
51+
handshake_metadata: Optional[Any] = None,
5452
):
5553
super().__init__(
5654
transport_id=client_id,
5755
transport_options=transport_options,
5856
is_server=False,
5957
)
60-
self._websocket_uri_factory = websocket_uri_factory
58+
self._websocket_uri = websocket_uri
6159
self._client_id = client_id
6260
self._server_id = server_id
6361
self._rate_limiter = LeakyBucketRateLimit(
6462
transport_options.connection_retry_options
6563
)
66-
self._handshake_metadata_factory = handshake_metadata_factory
64+
self._handshake_metadata = handshake_metadata
6765
# We want to make sure there's only one session creation at a time
6866
self._create_session_lock = asyncio.Lock()
6967

@@ -109,18 +107,12 @@ async def _establish_new_connection(
109107
break
110108
rate_limit.consume_budget(client_id)
111109
try:
112-
websocket_uri = await self._websocket_uri_factory()
113-
ws = await websockets.connect(websocket_uri)
110+
ws = await websockets.connect(self._websocket_uri)
114111
session_id = (
115112
self.generate_session_id()
116113
if not old_session
117114
else old_session.session_id
118115
)
119-
120-
handshake_metadata = None
121-
if self._handshake_metadata_factory is not None:
122-
handshake_metadata = await self._handshake_metadata_factory()
123-
124116
try:
125117
(
126118
handshake_request,
@@ -129,7 +121,7 @@ async def _establish_new_connection(
129121
self._transport_id,
130122
self._server_id,
131123
session_id,
132-
handshake_metadata,
124+
self._handshake_metadata,
133125
ws,
134126
old_session,
135127
)

tests/conftest.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,10 @@ async def client(
137137
transport_options: TransportOptions,
138138
no_logging_error: NoErrors,
139139
) -> AsyncGenerator[Client, None]:
140-
141-
async def websocket_uri_factory() -> str:
142-
return "ws://localhost:8765"
143-
144140
try:
145141
async with serve(server.serve, "localhost", 8765):
146142
client = Client(
147-
websocket_uri_factory,
143+
"ws://localhost:8765",
148144
client_id="test_client",
149145
server_id="test_server",
150146
transport_options=transport_options,

0 commit comments

Comments
 (0)