Skip to content

Commit c95ee45

Browse files
committed
handle client disconnection in the streaming non-batched case
1 parent 2cfd68e commit c95ee45

File tree

2 files changed

+38
-20
lines changed

2 files changed

+38
-20
lines changed

src/litserve/loops.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,13 @@ def run_batched_loop(
200200
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))
201201

202202

203-
def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
203+
def run_streaming_loop(
204+
lit_api: LitAPI,
205+
lit_spec: LitSpec,
206+
request_queue: Queue,
207+
response_queues: List[Queue],
208+
request_evicted_status: Dict[str, bool],
209+
):
204210
while True:
205211
try:
206212
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
@@ -239,6 +245,9 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue,
239245
y_gen,
240246
)
241247
for y_enc in y_enc_gen:
248+
if request_evicted_status.get(uid):
249+
request_evicted_status.pop(uid)
250+
break
242251
y_enc = lit_api.format_encoded_response(y_enc)
243252
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
244253
response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
@@ -325,6 +334,7 @@ def inference_worker(
325334
batch_timeout: float,
326335
stream: bool,
327336
workers_setup_status: Dict[str, bool] = None,
337+
request_evicted_status: Dict[str, bool] = None,
328338
):
329339
lit_api.setup(device)
330340
lit_api.device = device
@@ -340,7 +350,7 @@ def inference_worker(
340350
if max_batch_size > 1:
341351
run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
342352
else:
343-
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
353+
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status)
344354
return
345355

346356
if max_batch_size > 1:

src/litserve/server.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ async def response_queue_to_buffer(
8888
await asyncio.sleep(0.0001)
8989
continue
9090
stream_response_buffer, event = response_buffer[uid]
91-
stream_response_buffer.append(response)
91+
stream_response_buffer.append((uid, response))
9292
event.set()
9393

9494
else:
@@ -208,6 +208,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
208208
for _ in range(num_uvicorn_servers):
209209
response_queue = manager.Queue()
210210
self.response_queues.append(response_queue)
211+
self.request_evicted_status = manager.dict()
211212

212213
for spec in self._specs:
213214
# Objects of Server class are referenced (not copied)
@@ -240,6 +241,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
240241
self.batch_timeout,
241242
self.stream,
242243
self.workers_setup_status,
244+
self.request_evicted_status,
243245
),
244246
)
245247
process.start()
@@ -273,26 +275,32 @@ def device_identifiers(self, accelerator, device):
273275
return [f"{accelerator}:{device}"]
274276

275277
async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False):
278+
uid = None
276279
while True:
277-
await data_available.wait()
278-
while len(q) > 0:
279-
data, status = q.popleft()
280-
if status == LitAPIStatus.FINISH_STREAMING:
281-
return
282-
283-
if status == LitAPIStatus.ERROR:
284-
logger.error(
285-
"Error occurred while streaming outputs from the inference worker. "
286-
"Please check the above traceback."
287-
)
280+
try:
281+
await data_available.wait()
282+
while len(q) > 0:
283+
uid, (data, status) = q.popleft()
284+
if status == LitAPIStatus.FINISH_STREAMING:
285+
return
286+
if status == LitAPIStatus.ERROR:
287+
logger.error(
288+
"Error occurred while streaming outputs from the inference worker. "
289+
"Please check the above traceback."
290+
)
291+
if send_status:
292+
yield data, status
293+
return
288294
if send_status:
289295
yield data, status
290-
return
291-
if send_status:
292-
yield data, status
293-
else:
294-
yield data
295-
data_available.clear()
296+
else:
297+
yield data
298+
data_available.clear()
299+
except asyncio.CancelledError:
300+
if uid is not None:
301+
self.request_evicted_status[uid] = True
302+
logger.exception("Streaming request cancelled for the uid=%s", uid)
303+
return
296304

297305
def setup_server(self):
298306
workers_ready = False

0 commit comments

Comments
 (0)