Commit 9b0b02f

fix all_gather_into_tensor test and logic (#9332)
1 parent 0be0fbd commit 9b0b02f

2 files changed: +45 −14 lines

test/pjrt/test_collective_ops_tpu.py

Lines changed: 19 additions & 9 deletions
@@ -167,19 +167,25 @@ def callable(input):
       return input.cpu()

   @staticmethod
-  def _all_gather_into_tensor(use_dynamo: bool):
+  def _all_gather_into_tensor(use_dynamo: bool, mode: str):
     met.clear_all()

     def callable(output, input):
-      dist.all_gather_into_tensor(output_tensor, input, None)
-      return output_tensor
+      dist.all_gather_into_tensor(output, input, None)
+      return output

     dist.init_process_group("xla", init_method='xla://')
     device = torch_xla.device()
     input = torch.tensor([xr.global_ordinal()],
                          dtype=torch.float,
                          device=device)
-    output_tensor = torch.empty((1, xr.world_size()), device=device)
+    if mode == "stack":
+      output_tensor = torch.empty((xr.world_size(), 1), device=device)
+    elif mode == "concat":
+      output_tensor = torch.empty((xr.world_size(),), device=device)
+    else:
+      raise ValueError(f"mode must be either 'stack' or 'concat'")
+
     f = torch.compile(callable, backend='openxla') if use_dynamo else callable
     f(output_tensor, input)
     torch_xla.sync()

@@ -278,13 +284,17 @@ def test_all_reduce(self, use_dynamo):
     for index, val in results.items():
       torch.testing.assert_close(val, expected)

-  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
-  def test_all_gather_into_tensor(self, use_dynamo):
+  @parameterized.product(dynamo=[True, False], mode=["stack", "concat"])
+  def test_all_gather_into_tensor(self, dynamo, mode):
+    if dynamo and mode == "stack":
+      self.skipTest("https://github.com/pytorch/pytorch/issues/155632")
     results = pjrt.run_multiprocess(
-        self._all_gather_into_tensor, use_dynamo=use_dynamo)
+        self._all_gather_into_tensor, use_dynamo=dynamo, mode=mode)
     expected = torch.arange(
-        tpu.num_expected_global_devices(), dtype=torch.float).unsqueeze(0)
-    for index, val in results.items():
+        tpu.num_expected_global_devices(), dtype=torch.float)
+    if mode == "stack":
+      expected = expected.unsqueeze(1)
+    for _, val in results.items():
       torch.testing.assert_close(val, expected)

   @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))

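For context on what the parameterized test now checks: dist.all_gather_into_tensor accepts an output tensor that is either the concatenation of all per-rank inputs along dim 0 or a stack of them along a new dim 0. Below is a minimal sketch, not part of the commit, that exercises both layouts; it assumes an already-initialized process group, and the helper name gather_both_layouts is illustrative.

import torch
import torch.distributed as dist

def gather_both_layouts(rank: int, world_size: int, device: torch.device):
  # Each rank contributes a single scalar wrapped in a 1-element tensor,
  # mirroring the test's per-rank input of shape (1,).
  inp = torch.tensor([float(rank)], device=device)

  # "concat" layout: inputs concatenated along dim 0 -> shape (world_size,).
  concat_out = torch.empty((world_size,), device=device)
  dist.all_gather_into_tensor(concat_out, inp)

  # "stack" layout: inputs stacked along a new leading dim -> (world_size, 1).
  stack_out = torch.empty((world_size, 1), device=device)
  dist.all_gather_into_tensor(stack_out, inp)

  return concat_out, stack_out

With four ranks, concat_out ends up as tensor([0., 1., 2., 3.]) and stack_out as tensor([[0.], [1.], [2.], [3.]]), which matches the test's expected value of torch.arange over the device count, with the extra unsqueeze(1) only in the "stack" case.
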
torch_xla/distributed/xla_backend.py

Lines changed: 26 additions & 5 deletions
@@ -5,7 +5,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup
+from torch._C._distributed_c10d import ProcessGroup, AllgatherOptions


 def _create_xla_process_group(prefix_store, rank, size, timeout):

@@ -81,16 +81,37 @@ def allreduce(self, tensors, all_reduce_options):
     xm.all_reduce(reduce_type, tensors, groups=self._mesh, pin_layout=False)
     return _ret_work(tensors)

-  # method for dist.all_gather_into_tensor under eager mode.
-  def _allgather_base(self, output_tensor, input_tensor, opts):
-    return self.allgather(output_tensor, input_tensor, opts)
+  # This method is called for dist.all_gather_into_tensor under eager mode.
+  # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_into_tensor
+  def _allgather_base(self, output_tensor: torch.Tensor,
+                      input_tensor: torch.Tensor, opts: AllgatherOptions):
+    is_scalar = (input_tensor.dim() == 0)
+    if is_scalar:
+      input_tensor = torch.reshape(input_tensor, (1,))
+
+    result = xm.all_gather(
+        input_tensor, dim=0, groups=self._mesh, pin_layout=False)
+
+    if result.shape == output_tensor.shape:
+      output_tensor.copy_(result, non_blocking=True)
+      return _ret_work([output_tensor])
+
+    stacked_result = torch.stack(
+        torch.split(result, input_tensor.shape[0], dim=0), dim=0)
+    if stacked_result.shape == output_tensor.shape:
+      output_tensor.copy_(stacked_result, non_blocking=True)
+      return _ret_work([output_tensor])
+
+    msg = f"Input shape {input_tensor.shape} and output shape {output_tensor.shape} are not compatible for all_gather_into_tensor. Input must be stacked or concatenated to create output."
+    raise ValueError(msg)

   def allgather(self, output_tensors_list, input_tensors, opts=None):
     for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
       is_scalar = (input_tensor.dim() == 0)
       if is_scalar:
         input_tensor = torch.reshape(input_tensor, (1,))
-      result = xm.all_gather(input_tensor, groups=self._mesh, pin_layout=False)
+      result = xm.all_gather(
+          input_tensor, dim=0, groups=self._mesh, pin_layout=False)
       for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
         with torch.no_grad():
           output_tensors[i].copy_(

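To make the new shape dispatch in _allgather_base easier to follow, here is a standalone sketch, not from the commit, of the same decision with the collective replaced by a precomputed gathered tensor so it runs without a process group; the function and parameter names are illustrative.

import torch

def fill_allgather_output(gathered: torch.Tensor, per_rank_shape: torch.Size,
                          output: torch.Tensor) -> torch.Tensor:
  # Case 1: the caller wants the concatenated layout; shapes already match.
  if gathered.shape == output.shape:
    output.copy_(gathered)
    return output

  # Case 2: the caller wants the stacked layout; split the flat gather result
  # back into per-rank chunks and stack them along a new leading dimension.
  stacked = torch.stack(torch.split(gathered, per_rank_shape[0], dim=0), dim=0)
  if stacked.shape == output.shape:
    output.copy_(stacked)
    return output

  # Neither layout fits: mirror the backend's error for incompatible shapes.
  raise ValueError(
      f"Gathered shape {gathered.shape} cannot fill output shape "
      f"{output.shape}; output must be the concatenation or the stack of the "
      f"per-rank inputs.")

For example, with four ranks each contributing a tensor of shape (1,), gathered has shape (4,): an output of shape (4,) is filled directly (the concatenated layout), while an output of shape (4, 1) goes through the split-and-stack path, which are exactly the two cases the updated test covers.
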