Skip to content

Commit c7acbbb

Browse files
fix: check payload before accumulating bytes received (#430)
Aggregators sometimes send empty messages at the end of training. Previously, the trainer would throw an error when attempting to get the number of bytes received from an empty message. This is fixed by checking for an empty payload.
1 parent f1d9026 commit c7acbbb

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

lib/python/flame/channel.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,9 @@ async def _get():
229229
payload = None
230230
try:
231231
payload = await self._ends[end_id].get()
232-
# ignore timestamp for measuring bytes received
233-
self.mc.accumulate("bytes", "recv", len(payload[0]))
232+
if payload:
233+
# ignore timestamp for measuring bytes received
234+
self.mc.accumulate("bytes", "recv", len(payload[0]))
234235
except KeyError:
235236
return None
236237

@@ -345,8 +346,9 @@ async def _get_inner(end_id) -> tuple[str, Any]:
345346
payload = None
346347
try:
347348
payload = await self._ends[end_id].get()
348-
# ignore timestamp for measuring bytes received
349-
self.mc.accumulate("bytes", "recv", len(payload[0]))
349+
if payload:
350+
# ignore timestamp for measuring bytes received
351+
self.mc.accumulate("bytes", "recv", len(payload[0]))
350352
except KeyError:
351353
yield end_id, None
352354

0 commit comments

Comments
 (0)