@@ -313,6 +313,7 @@ def _worker_loop(
313
313
314
314
traceback .print_exc ()
315
315
logger .error (f"Error in worker loop (rank { rank } ): { e } " )
316
+ queue_out .put (e ) # any exception caught in the worker will be raised to the main process
316
317
finally :
317
318
del module
318
319
torch .cuda .synchronize ()
@@ -365,29 +366,44 @@ def load_loras(self, lora_args: List[Dict[str, any]], fused: bool = True):
365
366
}
366
367
)
367
368
try :
368
- _ = self .queue_out .get (timeout = PARALLEL_LORA_TIMEOUT_SEC )
369
+ res = self .queue_out .get (timeout = PARALLEL_LORA_TIMEOUT_SEC )
370
+ if isinstance (res , Exception ):
371
+ raise res
369
372
except Empty :
370
- logger .error ("Parallel model load LoRA timeout" )
371
- raise RuntimeError ("Parallel model load LoRA timeout" )
372
- logger .info ("Parallel model load LoRA done" )
373
+ logger .error ("ParallelModel load LoRA timeout" )
374
+ raise RuntimeError ("ParallelModel load LoRA timeout" )
375
+ except Exception as e :
376
+ logger .error (f"ParallelModel load LoRA error: { e } " )
377
+ raise RuntimeError (f"ParallelModel load LoRA error: { e } " )
378
+ logger .info ("ParallelModel load LoRA done" )
373
379
374
380
def unload_loras(self):
    """Ask the workers to unload their LoRAs and wait for the acknowledgement.

    Puts ``{"method": "unload_loras"}`` on ``queue_in`` and then blocks on
    ``queue_out`` for at most ``PARALLEL_LORA_TIMEOUT_SEC`` seconds.

    Raises:
        RuntimeError: if no reply arrives before the timeout, if reading the
            reply fails, or if the reply is an ``Exception`` instance (the
            worker loop puts exceptions it catches on ``queue_out``). The
            underlying exception is attached as ``__cause__``.
    """
    self.queue_in.put({"method": "unload_loras"})
    # Only the blocking read is guarded, so a genuine timeout (Empty raised
    # by ``get``) can never be confused with an exception object that was
    # *returned* through the queue by a worker.
    try:
        res = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
    except Empty as e:
        logger.error("ParallelModel unload LoRA timeout")
        raise RuntimeError("ParallelModel unload LoRA timeout") from e
    except Exception as e:
        message = f"ParallelModel unload LoRA error: {e}"
        logger.error(message)
        raise RuntimeError(message) from e

    if isinstance(res, Exception):
        # Surface the worker-side failure in the calling process, keeping the
        # original exception reachable via ``__cause__``.
        message = f"ParallelModel unload LoRA error: {res}"
        logger.error(message)
        raise RuntimeError(message) from res
    logger.info("ParallelModel unload LoRA done")
382
393
383
394
def forward(self, **kwargs):
    """Send one forward request to the workers and return their reply.

    The keyword arguments are put on ``queue_in`` as a single dict; the
    call then blocks on ``queue_out`` for at most
    ``PARALLEL_FWD_TIMEOUT_SEC`` seconds.

    Args:
        **kwargs: forwarded unchanged to the workers through ``queue_in``.

    Returns:
        Whatever object the workers put on ``queue_out`` for this request.

    Raises:
        RuntimeError: if no reply arrives before the timeout, if reading the
            reply fails, or if the reply is an ``Exception`` instance (the
            worker loop puts exceptions it catches on ``queue_out``). The
            underlying exception is attached as ``__cause__``.
    """
    self.queue_in.put(kwargs)
    # Only the blocking read is guarded, so a genuine timeout (Empty raised
    # by ``get``) can never be confused with an exception object that was
    # *returned* through the queue by a worker.
    try:
        res = self.queue_out.get(timeout=PARALLEL_FWD_TIMEOUT_SEC)
    except Empty as e:
        logger.error("ParallelModel forward timeout")
        raise RuntimeError("ParallelModel forward timeout") from e
    except Exception as e:
        message = f"ParallelModel forward error: {e}"
        logger.error(message)
        raise RuntimeError(message) from e

    if isinstance(res, Exception):
        # Surface the worker-side failure in the calling process, keeping the
        # original exception reachable via ``__cause__``.
        message = f"ParallelModel forward error: {res}"
        logger.error(message)
        raise RuntimeError(message) from res
    return res
391
407
392
408
def __del__ (self ):
393
409
# Send terminate signal to all workers
0 commit comments