From c95ee4540af5db40b5fd10d6279f2094c315edb0 Mon Sep 17 00:00:00 2001 From: bhimrazy Date: Sun, 1 Sep 2024 01:42:52 +0545 Subject: [PATCH] handle client disconnection streaming nonbatched case --- src/litserve/loops.py | 14 ++++++++++++-- src/litserve/server.py | 44 +++++++++++++++++++++++++----------------- 2 files changed, 38 insertions(+), 20 deletions(-) diff --git a/src/litserve/loops.py b/src/litserve/loops.py index 1c73b276..8093260a 100644 --- a/src/litserve/loops.py +++ b/src/litserve/loops.py @@ -200,7 +200,13 @@ def run_batched_loop( response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) -def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]): +def run_streaming_loop( + lit_api: LitAPI, + lit_spec: LitSpec, + request_queue: Queue, + response_queues: List[Queue], + request_evicted_status: Dict[str, bool], +): while True: try: response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0) @@ -239,6 +245,9 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, y_gen, ) for y_enc in y_enc_gen: + if request_evicted_status.get(uid): + request_evicted_status.pop(uid) + break y_enc = lit_api.format_encoded_response(y_enc) response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) @@ -325,6 +334,7 @@ def inference_worker( batch_timeout: float, stream: bool, workers_setup_status: Dict[str, bool] = None, + request_evicted_status: Dict[str, bool] = None, ): lit_api.setup(device) lit_api.device = device @@ -340,7 +350,7 @@ def inference_worker( if max_batch_size > 1: run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout) else: - run_streaming_loop(lit_api, lit_spec, request_queue, response_queues) + run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status) return if max_batch_size > 1: diff --git a/src/litserve/server.py b/src/litserve/server.py index d4e76000..37de1cc5 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -88,7 +88,7 @@ async def response_queue_to_buffer( await asyncio.sleep(0.0001) continue stream_response_buffer, event = response_buffer[uid] - stream_response_buffer.append(response) + stream_response_buffer.append((uid, response)) event.set() else: @@ -208,6 +208,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int): for _ in range(num_uvicorn_servers): response_queue = manager.Queue() self.response_queues.append(response_queue) + self.request_evicted_status = manager.dict() for spec in self._specs: # Objects of Server class are referenced (not copied) @@ -240,6 +241,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int): self.batch_timeout, self.stream, self.workers_setup_status, + self.request_evicted_status, ), ) process.start() @@ -273,26 +275,32 @@ def device_identifiers(self, accelerator, device): return [f"{accelerator}:{device}"] async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False): + uid = None while True: - await data_available.wait() - while len(q) > 0: - data, status = q.popleft() - if status == LitAPIStatus.FINISH_STREAMING: - return - - if status == LitAPIStatus.ERROR: - logger.error( - "Error occurred while streaming outputs from the inference worker. " - "Please check the above traceback." - ) + try: + await data_available.wait() + while len(q) > 0: + uid, (data, status) = q.popleft() + if status == LitAPIStatus.FINISH_STREAMING: + return + if status == LitAPIStatus.ERROR: + logger.error( + "Error occurred while streaming outputs from the inference worker. " + "Please check the above traceback." + ) + if send_status: + yield data, status + return if send_status: yield data, status - return - if send_status: - yield data, status - else: - yield data - data_available.clear() + else: + yield data + data_available.clear() + except asyncio.CancelledError: + if uid is not None: + self.request_evicted_status[uid] = True + logger.exception("Streaming request cancelled for the uid=%s", uid) + return def setup_server(self): workers_ready = False