@@ -88,7 +88,7 @@ async def response_queue_to_buffer(
88
88
await asyncio .sleep (0.0001 )
89
89
continue
90
90
stream_response_buffer , event = response_buffer [uid ]
91
- stream_response_buffer .append (response )
91
+ stream_response_buffer .append (( uid , response ) )
92
92
event .set ()
93
93
94
94
else :
@@ -208,6 +208,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
208
208
for _ in range (num_uvicorn_servers ):
209
209
response_queue = manager .Queue ()
210
210
self .response_queues .append (response_queue )
211
+ self .request_evicted_status = manager .dict ()
211
212
212
213
for spec in self ._specs :
213
214
# Objects of Server class are referenced (not copied)
@@ -240,6 +241,7 @@ def launch_inference_worker(self, num_uvicorn_servers: int):
240
241
self .batch_timeout ,
241
242
self .stream ,
242
243
self .workers_setup_status ,
244
+ self .request_evicted_status ,
243
245
),
244
246
)
245
247
process .start ()
@@ -273,26 +275,32 @@ def device_identifiers(self, accelerator, device):
273
275
return [f"{ accelerator } :{ device } " ]
274
276
275
277
async def data_streamer(self, q: deque, data_available: asyncio.Event, send_status: bool = False):
    """Asynchronously yield streamed items from ``q`` as they become available.

    Each entry popped from ``q`` is unpacked as ``(uid, (data, status))``.
    The generator waits on ``data_available``, drains everything currently
    in ``q``, clears the event and waits again.

    Args:
        q: Buffer of ``(uid, (data, status))`` entries for one request.
        data_available: Event that is set when new entries were appended to ``q``.
        send_status: When true, yield ``(data, status)`` tuples; otherwise
            yield ``data`` only.

    Yields:
        ``data`` or ``(data, status)`` depending on ``send_status``. An entry
        with ``LitAPIStatus.ERROR`` is only yielded when ``send_status`` is
        true, and always ends the stream.

    Raises:
        asyncio.CancelledError: Re-raised after the request was flagged in
            ``self.request_evicted_status`` so that the cancellation still
            reaches the task that is consuming this generator.
    """
    # Stays None until the first entry is popped; only then do we know which
    # request this stream belongs to.
    uid = None
    try:
        while True:
            await data_available.wait()
            while q:
                uid, (data, status) = q.popleft()
                if status == LitAPIStatus.FINISH_STREAMING:
                    return
                if status == LitAPIStatus.ERROR:
                    logger.error(
                        "Error occurred while streaming outputs from the inference worker. "
                        "Please check the above traceback."
                    )
                    if send_status:
                        yield data, status
                    return
                if send_status:
                    yield data, status
                else:
                    yield data
            # Everything buffered so far was consumed; block until the
            # producer sets the event again.
            data_available.clear()
    except asyncio.CancelledError:
        if uid is not None:
            # Flag the request as evicted for whoever reads this mapping.
            self.request_evicted_status[uid] = True
            logger.exception("Streaming request cancelled for the uid=%s", uid)
        # Never swallow a cancellation: returning here would make the
        # consumer's ``async for`` finish as if the stream ended normally.
        raise
296
304
297
305
def setup_server (self ):
298
306
workers_ready = False
0 commit comments