From 3b121151b552f7f9bfc1d023f6d4fdcda761df99 Mon Sep 17 00:00:00 2001 From: Yeounoh Chung Date: Tue, 8 Mar 2022 23:54:28 -0800 Subject: [PATCH] Merge xm all_gather patch (#3416) * Set proper shard_count for all_gather, when replica groups are non-empty. * Update test_mp_all_gather.py --- test/test_mp_all_gather.py | 32 ++++++++++++++++++++++++++------ torch_xla/core/xla_model.py | 8 +++++++- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/test/test_mp_all_gather.py b/test/test_mp_all_gather.py index 597ff079cec..6ffc854bf1b 100644 --- a/test/test_mp_all_gather.py +++ b/test/test_mp_all_gather.py @@ -8,21 +8,41 @@ def _mp_fn(index): device = xm.xla_device() + world_size = xm.xrt_world_size() if xm.xla_device_hw(device) in ('TPU', 'GPU'): - world_size = xm.xrt_world_size() + # Testing with a single replica group ordinal_tensor = torch.tensor([index], dtype=torch.float).to(device) - result = xm.all_gather(ordinal_tensor) + result = xm.all_gather(ordinal_tensor, dim=0) cpu_result = result.cpu() expected = torch.arange(0, world_size, dtype=torch.float) if not cpu_result.allclose(expected): print('xm.all_gather() produced wrong reductions', file=sys.stderr) - print('[{}] {}'.format(index, cpu_result), file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) sys.exit(1) + + # Testing with two replica groups + if world_size % 2 == 0 and world_size > 1: + mp_groups = [[n for n in range(world_size) if n % 2 == 0], + [n for n in range(world_size) if n % 2 == 1]] + group_size = len(mp_groups[0]) + replica_id = int(index % 2 == 1) + + result = xm.all_gather(ordinal_tensor, dim=0, groups=mp_groups) + + cpu_result = result.cpu() + expected = torch.arange(replica_id, world_size, step=2, dtype=torch.float) + if not cpu_result.allclose(expected): + print('xm.all_gather() produced wrong reductions', file=sys.stderr) + print(f'[{index}] {cpu_result}', file=sys.stderr) + sys.exit(1) + else: + print( + f'Failed to create two replica groups with {world_size} replicas', + file=sys.stderr) + else: - print( - 'Default device {} is not a TPU or GPU device'.format(device), - file=sys.stderr) + print(f'{device} is not a TPU or GPU device', file=sys.stderr) if __name__ == '__main__': diff --git a/torch_xla/core/xla_model.py b/torch_xla/core/xla_model.py index cb633cf60c5..413ed8fa270 100644 --- a/torch_xla/core/xla_model.py +++ b/torch_xla/core/xla_model.py @@ -603,7 +603,13 @@ def all_gather(value, dim=0, groups=None, output=None): if dim < 0: dim = value.dim() + dim token, devctx = _get_all_reduce_token() - shard_count = None if groups else xrt_world_size() + if groups: + shard_count = len(groups[0]) + assert all(len(group) == shard_count for group in groups), \ + "Replica groups must have the same number of replicas/shards." + else: + # All replicas belong to a single group + shard_count = xrt_world_size() if output != None: # Call the out of place version of the all_gather new_token = torch_xla._XLAC._xla_all_gather_out(output, value, token, dim,