From 5423b45f99b8b0211dcbbaf1ef609fb31b69cc0a Mon Sep 17 00:00:00 2001 From: "Le, Jiang" <55124727+charles9304@users.noreply.github.com> Date: Fri, 30 Aug 2024 13:34:30 +0800 Subject: [PATCH] raise error if parameter sync breaks (#42) --- chatlearn/runtime/parameter_sync.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/chatlearn/runtime/parameter_sync.py b/chatlearn/runtime/parameter_sync.py index 1d494d51..3b0e96c6 100644 --- a/chatlearn/runtime/parameter_sync.py +++ b/chatlearn/runtime/parameter_sync.py @@ -232,7 +232,7 @@ def sync_broadcast(self, actors, group_name, requires_grad=None): for rank, actor in enumerate(actors): ref = actor.broadcast_parameter.remote(rank, 0, group_name, pipe_stage) refs.append(ref) - future.wait(refs) + future.wait(refs, return_output=True) def _sync_send_recv(self, send_actor, recv_actor, requires_grad=None): @@ -429,6 +429,11 @@ def sync(self, requires_grad=None): else: for recv_actor in recv_actors: futures.append(executor.submit(self.sync_send_recv, send_actor, recv_actor, requires_grad)) + for _future in concurrent.futures.as_completed(futures): + try: + _future.result() + except Exception as e: + raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from concurrent.futures.wait(futures) else: for send_actor, recv_actors in self.send_recv_actor_mappings.items():