Skip to content

Commit 232ada2

Browse files
authored
raise parallel error in main process (#66)
1 parent 9d8a62e commit 232ada2

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

diffsynth_engine/utils/parallel.py

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -313,6 +313,7 @@ def _worker_loop(
313313

314314
traceback.print_exc()
315315
logger.error(f"Error in worker loop (rank {rank}): {e}")
316+
queue_out.put(e) # any exception caught in the worker will be raised to the main process
316317
finally:
317318
del module
318319
torch.cuda.synchronize()
@@ -365,29 +366,44 @@ def load_loras(self, lora_args: List[Dict[str, any]], fused: bool = True):
365366
}
366367
)
367368
try:
368-
_ = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
369+
res = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
370+
if isinstance(res, Exception):
371+
raise res
369372
except Empty:
370-
logger.error("Parallel model load LoRA timeout")
371-
raise RuntimeError("Parallel model load LoRA timeout")
372-
logger.info("Parallel model load LoRA done")
373+
logger.error("ParallelModel load LoRA timeout")
374+
raise RuntimeError("ParallelModel load LoRA timeout")
375+
except Exception as e:
376+
logger.error(f"ParallelModel load LoRA error: {e}")
377+
raise RuntimeError(f"ParallelModel load LoRA error: {e}")
378+
logger.info("ParallelModel load LoRA done")
373379

374380
def unload_loras(self):
375381
self.queue_in.put({"method": "unload_loras"})
376382
try:
377-
_ = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
383+
res = self.queue_out.get(timeout=PARALLEL_LORA_TIMEOUT_SEC)
384+
if isinstance(res, Exception):
385+
raise res
378386
except Empty:
379-
logger.error("Parallel model unload LoRA timeout")
380-
raise RuntimeError("Parallel model unload LoRA timeout")
381-
logger.info("Parallel model unload LoRA done")
387+
logger.error("ParallelModel unload LoRA timeout")
388+
raise RuntimeError("ParallelModel unload LoRA timeout")
389+
except Exception as e:
390+
logger.error(f"ParallelModel unload LoRA error: {e}")
391+
raise RuntimeError(f"ParallelModel unload LoRA error: {e}")
392+
logger.info("ParallelModel unload LoRA done")
382393

383394
def forward(self, **kwargs):
384395
self.queue_in.put(kwargs)
385396
try:
386-
y = self.queue_out.get(timeout=PARALLEL_FWD_TIMEOUT_SEC)
397+
res = self.queue_out.get(timeout=PARALLEL_FWD_TIMEOUT_SEC)
398+
if isinstance(res, Exception):
399+
raise res
387400
except Empty:
388-
logger.error("Parallel model forward timeout")
389-
raise RuntimeError("Parallel model forward timeout")
390-
return y
401+
logger.error("ParallelModel forward timeout")
402+
raise RuntimeError("ParallelModel forward timeout")
403+
except Exception as e:
404+
logger.error(f"ParallelModel forward error: {e}")
405+
raise RuntimeError(f"ParallelModel forward error: {e}")
406+
return res
391407

392408
def __del__(self):
393409
# Send terminate signal to all workers

0 commit comments

Comments
 (0)