Skip to content

Commit

Permalink
raise error if parameter sync breaks (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
charles9304 authored Aug 30, 2024
1 parent 5393709 commit 5423b45
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion chatlearn/runtime/parameter_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def sync_broadcast(self, actors, group_name, requires_grad=None):
for rank, actor in enumerate(actors):
ref = actor.broadcast_parameter.remote(rank, 0, group_name, pipe_stage)
refs.append(ref)
future.wait(refs)
future.wait(refs, return_output=True)


def _sync_send_recv(self, send_actor, recv_actor, requires_grad=None):
Expand Down Expand Up @@ -429,6 +429,11 @@ def sync(self, requires_grad=None):
else:
for recv_actor in recv_actors:
futures.append(executor.submit(self.sync_send_recv, send_actor, recv_actor, requires_grad))
for _future in concurrent.futures.as_completed(futures):
try:
_future.result()
except Exception as e:
raise RuntimeError(f"Parameter sync thread generated an exception: {e}") # pylint: disable=raise-missing-from
concurrent.futures.wait(futures)
else:
for send_actor, recv_actors in self.send_recv_actor_mappings.items():
Expand Down

0 comments on commit 5423b45

Please sign in to comment.