Skip to content

Commit d5aabb4

Browse files
feat/typed metadata (#72)
Why === From a discussion with Jacky, egged on by the revert of #62, we should know the schema of metadata that we're threading through. What changed ============ `river.Client` is now bound by `A`, the schema of the structure passed in to `handshake_metadata`. If this agrees with the generated client's expectations, we're good to go. Test plan ========= Manual testing
1 parent 2dc3593 commit d5aabb4

File tree

6 files changed

+54
-28
lines changed

6 files changed

+54
-28
lines changed

replit_river/client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
from collections.abc import AsyncIterable, AsyncIterator
3-
from typing import Any, Callable, Optional, Union
3+
from typing import Any, Callable, Generic, Optional, TypeVar, Union
44

55
from replit_river.client_transport import ClientTransport
66
from replit_river.transport_options import TransportOptions
@@ -15,18 +15,21 @@
1515
logger = logging.getLogger(__name__)
1616

1717

18-
class Client:
18+
HandshakeType = TypeVar("HandshakeType")
19+
20+
21+
class Client(Generic[HandshakeType]):
1922
def __init__(
2023
self,
2124
websocket_uri: str,
2225
client_id: str,
2326
server_id: str,
2427
transport_options: TransportOptions,
25-
handshake_metadata: Optional[Any] = None,
28+
handshake_metadata: Optional[HandshakeType] = None,
2629
) -> None:
2730
self._client_id = client_id
2831
self._server_id = server_id
29-
self._transport = ClientTransport(
32+
self._transport = ClientTransport[HandshakeType](
3033
websocket_uri=websocket_uri,
3134
client_id=client_id,
3235
server_id=server_id,

replit_river/client_transport.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import logging
3-
from typing import Any, Optional, Tuple
3+
from typing import Generic, Optional, Tuple, TypeVar
44

55
import websockets
66
from pydantic import ValidationError
@@ -41,14 +41,17 @@
4141
logger = logging.getLogger(__name__)
4242

4343

44-
class ClientTransport(Transport):
44+
HandshakeType = TypeVar("HandshakeType")
45+
46+
47+
class ClientTransport(Transport, Generic[HandshakeType]):
4548
def __init__(
4649
self,
4750
websocket_uri: str,
4851
client_id: str,
4952
server_id: str,
5053
transport_options: TransportOptions,
51-
handshake_metadata: Optional[Any] = None,
54+
handshake_metadata: Optional[HandshakeType] = None,
5255
):
5356
super().__init__(
5457
transport_id=client_id,
@@ -91,7 +94,7 @@ async def _establish_new_connection(
9194
old_session: Optional[ClientSession] = None,
9295
) -> Tuple[
9396
WebSocketCommonProtocol,
94-
ControlMessageHandshakeRequest,
97+
ControlMessageHandshakeRequest[HandshakeType],
9598
ControlMessageHandshakeResponse,
9699
]:
97100
"""Build a new websocket connection with retry logic."""
@@ -204,11 +207,11 @@ async def _send_handshake_request(
204207
transport_id: str,
205208
to_id: str,
206209
session_id: str,
207-
handshake_metadata: Optional[Any],
210+
handshake_metadata: Optional[HandshakeType],
208211
websocket: WebSocketCommonProtocol,
209212
expected_session_state: ExpectedSessionState,
210-
) -> ControlMessageHandshakeRequest:
211-
handshake_request = ControlMessageHandshakeRequest(
213+
) -> ControlMessageHandshakeRequest[HandshakeType]:
214+
handshake_request = ControlMessageHandshakeRequest[HandshakeType](
212215
type="HANDSHAKE_REQ",
213216
protocolVersion=PROTOCOL_VERSION,
214217
sessionId=session_id,
@@ -273,10 +276,12 @@ async def _establish_handshake(
273276
transport_id: str,
274277
to_id: str,
275278
session_id: str,
276-
handshake_metadata: Optional[Any],
279+
handshake_metadata: Optional[HandshakeType],
277280
websocket: WebSocketCommonProtocol,
278281
old_session: Optional[ClientSession],
279-
) -> Tuple[ControlMessageHandshakeRequest, ControlMessageHandshakeResponse]:
282+
) -> Tuple[
283+
ControlMessageHandshakeRequest[HandshakeType], ControlMessageHandshakeResponse
284+
]:
280285
try:
281286
handshake_request = await self._send_handshake_request(
282287
transport_id=transport_id,

replit_river/codegen/client.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ class RiverService(BaseModel):
6060

6161
class RiverSchema(BaseModel):
6262
services: Dict[str, RiverService]
63+
handshakeSchema: RiverConcreteType
6364

6465

6566
RiverSchemaFile = RootModel[RiverSchema]
@@ -264,12 +265,21 @@ def generate_river_client_module(
264265
"import replit_river as river",
265266
"",
266267
]
268+
269+
(handshake_type, handshake_chunks) = encode_type(
270+
schema_root.handshakeSchema, "HandshakeSchema"
271+
)
272+
chunks.extend(handshake_chunks)
273+
267274
for schema_name, schema in schema_root.services.items():
268275
current_chunks: List[str] = [
269-
f"class {schema_name.title()}Service:",
270-
" def __init__(self, client: river.Client):",
271-
" self.client = client",
272-
"",
276+
dedent(
277+
f"""\
278+
class {schema_name.title()}Service:
279+
def __init__(self, client: river.Client[{handshake_type}]):
280+
self.client = client
281+
"""
282+
),
273283
]
274284
for name, procedure in schema.procedures.items():
275285
init_type: Optional[str] = None
@@ -309,13 +319,13 @@ def generate_river_client_module(
309319
.validate_python(
310320
x # type: ignore[arg-type]
311321
)
312-
""".strip()
322+
""".rstrip()
313323
parse_error_method = f"""\
314324
lambda x: TypeAdapter({error_type})
315325
.validate_python(
316326
x # type: ignore[arg-type]
317327
)
318-
""".strip()
328+
""".rstrip()
319329

320330
if output_type == "None":
321331
parse_output_method = "lambda x: None"
@@ -506,8 +516,12 @@ async def {name}(
506516

507517
chunks.extend(
508518
[
509-
f"class {client_name}:",
510-
" def __init__(self, client: river.Client):",
519+
dedent(
520+
f"""\
521+
class {client_name}:
522+
def __init__(self, client: river.Client[{handshake_type}]):
523+
""".rstrip()
524+
)
511525
]
512526
)
513527
for schema_name, schema in schema_root.services.items():

replit_river/rpc.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
Callable,
88
Coroutine,
99
Dict,
10+
Generic,
1011
Iterable,
1112
Literal,
1213
Mapping,
@@ -61,12 +62,15 @@ class ExpectedSessionState(BaseModel):
6162
nextSentSeq: Optional[int] = None
6263

6364

64-
class ControlMessageHandshakeRequest(BaseModel):
65+
HandshakeType = TypeVar("HandshakeType")
66+
67+
68+
class ControlMessageHandshakeRequest(BaseModel, Generic[HandshakeType]):
6569
type: Literal["HANDSHAKE_REQ"] = "HANDSHAKE_REQ"
6670
protocolVersion: str
6771
sessionId: str
6872
expectedSessionState: ExpectedSessionState
69-
metadata: Optional[Any] = None
73+
metadata: Optional[HandshakeType] = None
7074

7175

7276
class HandShakeStatus(BaseModel):

replit_river/server_transport.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Tuple
2+
from typing import Any, Tuple
33

44
import nanoid # type: ignore # type: ignore
55
from pydantic import ValidationError
@@ -113,11 +113,11 @@ async def websocket_closed_callback() -> None:
113113
async def _establish_handshake(
114114
self, request_message: TransportMessage, websocket: WebSocketCommonProtocol
115115
) -> Tuple[
116-
ControlMessageHandshakeRequest,
116+
ControlMessageHandshakeRequest[Any],
117117
ControlMessageHandshakeResponse,
118118
]:
119119
try:
120-
handshake_request = ControlMessageHandshakeRequest(
120+
handshake_request = ControlMessageHandshakeRequest[Any](
121121
**request_message.payload
122122
)
123123
logger.debug('Got handshake request "%r"', handshake_request)

tests/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import asyncio
22
import logging
33
from collections.abc import AsyncIterator
4-
from typing import Any, AsyncGenerator
4+
from typing import Any, AsyncGenerator, NoReturn
55

66
import nanoid # type: ignore
77
import pytest
@@ -139,7 +139,7 @@ async def client(
139139
) -> AsyncGenerator[Client, None]:
140140
try:
141141
async with serve(server.serve, "localhost", 8765):
142-
client = Client(
142+
client: Client[NoReturn] = Client(
143143
"ws://localhost:8765",
144144
client_id="test_client",
145145
server_id="test_server",

0 commit comments

Comments
 (0)