diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py index 1d494d51..3b0e96c6 100644 --- a/chatlearn/runtime/parameter_sync.py +++ b/chatlearn/runtime/parameter_sync.py @@ -232,7 +232,7 @@ def sync_broadcast(self, actors, group_name, requires_grad=None): for rank, actor in enumerate(actors): ref = actor.broadcast_parameter.remote(rank, 0, group_name, pipe_stage) refs.append(ref) - future.wait(refs) + future.wait(refs, return_output=True) def _sync_send_recv(self, send_actor, recv_actor, requires_grad=None): @@ -429,6 +429,11 @@ def sync(self, requires_grad=None): else: for recv_actor in recv_actors: futures.append(executor.submit(self.sync_send_recv, send_actor, recv_actor, requires_grad)) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) else: for send_actor, recv_actors in self.send_recv_actor_mappings.items():