@@ -107,6 +107,9 @@ def __init__(
107107 self ._buffer = MessageBuffer (self ._transport_options .buffer_size )
108108 self ._task_manager = BackgroundTaskManager ()
109109
110+ # Start the buffered message sender task
111+ self ._start_buffered_message_sender ()
112+
110113 def _setup_heartbeats_task (
111114 self ,
112115 do_close_websocket : Callable [[], Awaitable [None ]],
@@ -142,6 +145,38 @@ def increment_and_get_heartbeat_misses() -> int:
142145 )
143146 )
144147
148+ def _start_buffered_message_sender (self ) -> None :
149+ """Start the background task that sends messages from the buffer."""
150+ from replit_river .common_session import buffered_message_sender
151+
152+ async def commit (msg : TransportMessage ) -> None :
153+ # Remove messages that have been acknowledged
154+ await self ._buffer .remove_old_messages (msg .seq + 1 )
155+
156+ def get_next_pending () -> TransportMessage | None :
157+ return self ._buffer .peek ()
158+
159+ def get_ws () -> websockets .WebSocketCommonProtocol | None :
160+ if self ._ws_wrapper .is_open ():
161+ return self ._ws_wrapper .ws
162+ return None
163+
164+ async def block_until_connected () -> None :
165+ while self ._state in [SessionState .NO_CONNECTION , SessionState .CONNECTING ]:
166+ await asyncio .sleep (0.1 )
167+
168+ self ._task_manager .create_task (
169+ buffered_message_sender (
170+ block_until_connected = block_until_connected ,
171+ block_until_message_available = self ._buffer .block_until_message_available ,
172+ get_ws = get_ws ,
173+ websocket_closed_callback = self ._begin_close_session_countdown ,
174+ get_next_pending = get_next_pending ,
175+ commit = commit ,
176+ get_state = lambda : self ._state ,
177+ )
178+ )
179+
145180 async def is_session_open (self ) -> bool :
146181 async with self ._state_lock :
147182 return self ._state == SessionState .ACTIVE
@@ -181,24 +216,6 @@ async def replace_with_new_websocket(
181216 await old_wrapper .close ()
182217 self ._ws_wrapper = WebsocketWrapper (new_ws )
183218
184- # Send buffered messages to the new ws
185- buffered_messages = list (self ._buffer .buffer )
186- for msg in buffered_messages :
187- try :
188- await send_transport_message (
189- msg ,
190- new_ws ,
191- self ._begin_close_session_countdown ,
192- )
193- except WebsocketClosedException :
194- logger .info (
195- "Connection closed while sending buffered messages" , exc_info = True
196- )
197- break
198- except FailedSendingMessageException :
199- logger .exception ("Error while sending buffered messages" )
200- break
201-
202219 async def _get_current_time (self ) -> float :
203220 return asyncio .get_event_loop ().time ()
204221
@@ -249,30 +266,10 @@ async def send_message(
249266 with use_span (span ):
250267 trace_propagator .inject (msg , None , trace_setter )
251268 try :
252- try :
253- self ._buffer .put (msg )
254- except MessageBufferClosedError :
255- # The session is closed and is no longer accepting new messages.
256- return
257- async with self ._ws_lock :
258- if not self ._ws_wrapper .is_open ():
259- # If the websocket is closed, we should not send the message
260- # and wait for the retry from the buffer.
261- return
262- await send_transport_message (
263- msg , self ._ws_wrapper .ws , self ._begin_close_session_countdown
264- )
265- except WebsocketClosedException as e :
266- logger .debug (
267- "Connection closed while sending message %r, waiting for "
268- "retry from buffer" ,
269- type (e ),
270- exc_info = e ,
271- )
272- except FailedSendingMessageException :
273- logger .error (
274- "Failed sending message, waiting for retry from buffer" , exc_info = True
275- )
269+ self ._buffer .put (msg )
270+ except MessageBufferClosedError :
271+ # The session is closed and is no longer accepting new messages.
272+ return
276273
277274 async def close_websocket (
278275 self , ws_wrapper : WebsocketWrapper , should_retry : bool
0 commit comments