Skip to content

Commit c837786

Browse files
authored
Properly close streams on exception (#167)
Why === We've had persistent timeout errors in AI-Infra, and I suspect that it's related to not handling bumps in the connection correctly. What changed ============ - WebSocket drops or send failures left _streams populated, so any in-flight RPC hung until the session fully shut down. That meant clients didn’t see an abort signal and could block indefinitely even though the transport was already defunct. - Added _abort_all_streams() in src/replit_river/session.py#L289 and call it from both client_session.serve() and server_session.serve() on ConnectionClosed, FailedSendingMessageException, or any other unexpected exception (src/replit_river/client_session.py#L95, src/replit_river/server_session.py#L82). This immediately closes every active channel and clears _streams, ensuring callers are notified right away when the socket dies so they can retry or surface an error. Test plan ========= CI/CD, ran against an internal branch with no issues 3x without flake.
1 parent dcec28b commit c837786

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

src/replit_river/client_session.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,17 +99,22 @@ async def serve(self) -> None:
9999
try:
100100
await self._handle_messages_from_ws()
101101
except ConnectionClosed:
102+
if self._should_abort_streams_after_transport_failure():
103+
await self.close()
102104
if self._retry_connection_callback:
103105
self._task_manager.create_task(self._retry_connection_callback())
104106

105107
await self._begin_close_session_countdown()
106108
logger.debug("ConnectionClosed while serving", exc_info=True)
107109
except FailedSendingMessageException:
108110
# Expected error if the connection is closed.
111+
if self._should_abort_streams_after_transport_failure():
112+
await self.close()
109113
logger.debug(
110114
"FailedSendingMessageException while serving", exc_info=True
111115
)
112116
except Exception:
117+
await self.close()
113118
logger.exception("caught exception at message iterator")
114119
except ExceptionGroup as eg:
115120
_, unhandled = eg.split(lambda e: isinstance(e, ConnectionClosed))

src/replit_river/session.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,17 @@ async def close_websocket(
286286
if should_retry and self._retry_connection_callback:
287287
self._task_manager.create_task(self._retry_connection_callback())
288288

289+
def _should_abort_streams_after_transport_failure(self) -> bool:
290+
return not self._transport_options.transparent_reconnect
291+
292+
def _abort_all_streams(self) -> None:
293+
"""Close all active stream channels, notifying any waiting consumers."""
294+
if not self._streams:
295+
return
296+
for stream in self._streams.values():
297+
stream.close()
298+
self._streams.clear()
299+
289300
async def close(self) -> None:
290301
"""Close the session and all associated streams."""
291302
logger.info(
@@ -310,9 +321,7 @@ async def close(self) -> None:
310321

311322
# TODO: unexpected_close should close stream differently here to
312323
# throw exception correctly.
313-
for stream in self._streams.values():
314-
stream.close()
315-
self._streams.clear()
324+
self._abort_all_streams()
316325

317326
self._state = SessionState.CLOSED
318327

0 commit comments

Comments
 (0)