Skip to content

Commit

Permalink
Handle client disconnection in the streaming, non-batched case
Browse files Browse the repository at this point in the history
  • Loading branch information
bhimrazy committed Aug 31, 2024
1 parent 2cfd68e commit c95ee45
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 20 deletions.
14 changes: 12 additions & 2 deletions src/litserve/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,13 @@ def run_batched_loop(
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue, response_queues: List[Queue]):
def run_streaming_loop(
lit_api: LitAPI,
lit_spec: LitSpec,
request_queue: Queue,
response_queues: List[Queue],
request_evicted_status: Dict[str, bool],
):
while True:
try:
response_queue_id, uid, timestamp, x_enc = request_queue.get(timeout=1.0)
Expand Down Expand Up @@ -239,6 +245,9 @@ def run_streaming_loop(lit_api: LitAPI, lit_spec: LitSpec, request_queue: Queue,
y_gen,
)
for y_enc in y_enc_gen:
if request_evicted_status.get(uid):
request_evicted_status.pop(uid)
break
y_enc = lit_api.format_encoded_response(y_enc)
response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
Expand Down Expand Up @@ -325,6 +334,7 @@ def inference_worker(
batch_timeout: float,
stream: bool,
workers_setup_status: Dict[str, bool] = None,
request_evicted_status: Dict[str, bool] = None,
):
lit_api.setup(device)
lit_api.device = device
Expand All @@ -340,7 +350,7 @@ def inference_worker(
if max_batch_size > 1:
run_batched_streaming_loop(lit_api, lit_spec, request_queue, response_queues, max_batch_size, batch_timeout)
else:
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues)
run_streaming_loop(lit_api, lit_spec, request_queue, response_queues, request_evicted_status)
return

if max_batch_size > 1:
Expand Down
44 changes: 26 additions & 18 deletions src/litserve/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def response_queue_to_buffer(
await asyncio.sleep(0.0001)
continue
stream_response_buffer, event = response_buffer[uid]
stream_response_buffer.append(response)
stream_response_buffer.append((uid, response))
event.set()

else:
Expand Down Expand Up @@ -208,6 +208,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
for _ in range(num_uvicorn_servers):
response_queue = manager.Queue()
self.response_queues.append(response_queue)
self.request_evicted_status = manager.dict()

for spec in self._specs:
# Objects of Server class are referenced (not copied)
Expand Down Expand Up @@ -240,6 +241,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
self.batch_timeout,
self.stream,
self.workers_setup_status,
self.request_evicted_status,
),
)
process.start()
Expand Down Expand Up @@ -273,26 +275,32 @@ def device_identifiers(self, accelerator, device):
return [f"{accelerator}:{device}"]

async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False):
    """Stream buffered worker responses for a single request to the client.

    Repeatedly waits on *data_available*, drains ``(uid, (payload, state))``
    tuples from *q*, and yields each payload (paired with its status when
    *send_status* is true) until the worker marks the stream finished or an
    error occurs. If the client disconnects, the pending request's uid is
    flagged in ``self.request_evicted_status`` so the inference worker can
    stop generating output for it.
    """
    # Remember the uid of the request being streamed so the cancellation
    # handler can flag it for eviction on the worker side.
    current_uid = None
    while True:
        try:
            await data_available.wait()
            while q:
                current_uid, (payload, state) = q.popleft()
                if state == LitAPIStatus.FINISH_STREAMING:
                    # Worker signalled end of stream: stop the generator.
                    return
                if state == LitAPIStatus.ERROR:
                    logger.error(
                        "Error occurred while streaming outputs from the inference worker. "
                        "Please check the above traceback."
                    )
                    # Surface the error payload to the caller when statuses
                    # are requested, then terminate the stream.
                    if send_status:
                        yield payload, state
                    return
                yield (payload, state) if send_status else payload
            # Buffer drained; reset the event and wait for more data.
            data_available.clear()
        except asyncio.CancelledError:
            # Client disconnected mid-stream: tell the worker to stop
            # producing output for this request.
            if current_uid is not None:
                self.request_evicted_status[current_uid] = True
            logger.exception("Streaming request cancelled for the uid=%s", current_uid)
            return

def setup_server(self):
workers_ready = False
Expand Down

0 comments on commit c95ee45

Please sign in to comment.