
Commit aad8710 (parent: d82e15c)

fix all_gather_into_tensor test and logic

2 files changed: +41 -13 lines changed

test/pjrt/test_collective_ops_tpu.py

Lines changed: 19 additions & 9 deletions

@@ -167,19 +167,25 @@ def callable(input):
       return input.cpu()
 
   @staticmethod
-  def _all_gather_into_tensor(use_dynamo: bool):
+  def _all_gather_into_tensor(use_dynamo: bool, mode: str):
     met.clear_all()
+    allowed_modes = ["stack", "concat"]
+    if mode not in allowed_modes:
+      raise ValueError(f"mode must be one of {allowed_modes}")
 
     def callable(output, input):
-      dist.all_gather_into_tensor(output_tensor, input, None)
-      return output_tensor
+      dist.all_gather_into_tensor(output, input, None)
+      return output
 
     dist.init_process_group("xla", init_method='xla://')
     device = torch_xla.device()
     input = torch.tensor([xr.global_ordinal()],
                         dtype=torch.float,
                         device=device)
-    output_tensor = torch.empty((1, xr.world_size()), device=device)
+    if mode == "stack":
+      output_tensor = torch.empty((xr.world_size(), 1), device=device)
+    elif mode == "concat":
+      output_tensor = torch.empty((xr.world_size(),), device=device)
     f = torch.compile(callable, backend='openxla') if use_dynamo else callable
     f(output_tensor, input)
     torch_xla.sync()
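
For reference, the "concat" and "stack" modes exercised above correspond to the two output layouts dist.all_gather_into_tensor can be handed here: with a per-rank input of shape (1,), "concat" lays the gathered values out along dim 0, while "stack" keeps a leading per-rank dimension. A minimal single-process shape sketch (the world size of 4 is illustrative, not taken from the commit):

import torch

world_size = 4  # illustrative; the real test uses xr.world_size()
per_rank = [torch.tensor([float(r)]) for r in range(world_size)]  # each shape (1,)

concat_out = torch.cat(per_rank, dim=0)   # shape (4,)  -> mode="concat"
stack_out = torch.stack(per_rank, dim=0)  # shape (4, 1) -> mode="stack"

assert concat_out.shape == (world_size,)
assert stack_out.shape == (world_size, 1)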
@@ -278,13 +284,17 @@ def test_all_reduce(self, use_dynamo):
     for index, val in results.items():
       torch.testing.assert_close(val, expected)
 
-  @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
-  def test_all_gather_into_tensor(self, use_dynamo):
+  @parameterized.product(dynamo=[True, False], mode=["stack", "concat"])
+  def test_all_gather_into_tensor(self, dynamo, mode):
+    if dynamo and mode == "stack":
+      self.skipTest("https://github.com/pytorch/pytorch/issues/155632")
     results = pjrt.run_multiprocess(
-        self._all_gather_into_tensor, use_dynamo=use_dynamo)
+        self._all_gather_into_tensor, use_dynamo=dynamo, mode=mode)
     expected = torch.arange(
-        tpu.num_expected_global_devices(), dtype=torch.float).unsqueeze(0)
-    for index, val in results.items():
+        tpu.num_expected_global_devices(), dtype=torch.float)
+    if mode == "stack":
+      expected = expected.unsqueeze(1)
+    for _, val in results.items():
       torch.testing.assert_close(val, expected)
 
   @parameterized.named_parameters(('dynamo', True), ('nondynamo', False))
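
For an illustrative four-device run (the test derives the real count from tpu.num_expected_global_devices()), the expected tensors built above come out as:

import torch

num_devices = 4  # illustrative only
expected_concat = torch.arange(num_devices, dtype=torch.float)  # tensor([0., 1., 2., 3.]), shape (4,)
expected_stack = expected_concat.unsqueeze(1)                   # shape (4, 1), one row per rank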

torch_xla/distributed/xla_backend.py

Lines changed: 22 additions & 4 deletions

@@ -5,7 +5,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup
+from torch._C._distributed_c10d import ProcessGroup, AllgatherOptions
 
 
 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -82,15 +82,33 @@ def allreduce(self, tensors, all_reduce_options):
     return _ret_work(tensors)
 
   # method for dist.all_gather_into_tensor under eager mode.
-  def _allgather_base(self, output_tensor, input_tensor, opts):
-    return self.allgather(output_tensor, input_tensor, opts)
+  def _allgather_base(self, output_tensor: torch.Tensor,
+                      input_tensor: torch.Tensor, opts: AllgatherOptions):
+    is_scalar = (input_tensor.dim() == 0)
+    if is_scalar:
+      input_tensor = torch.reshape(input_tensor, (1,))
+    result = xm.all_gather(
+        input_tensor, dim=0, groups=self._mesh, pin_layout=False)
+    if result.shape == output_tensor.shape:
+      output_tensor.copy_(result, non_blocking=True)
+    else:
+      stacked_result = torch.stack(
+          torch.split(result, input_tensor.shape[0], dim=0), dim=0)
+      if stacked_result.shape == output_tensor.shape:
+        output_tensor.copy_(stacked_result, non_blocking=True)
+      else:
+        msg = f"Input shape {input_tensor.shape} and output shape {output_tensor.shape} are not compatible for all_gather_into_tensor. Input must be stacked or concatenated to create output."
+        raise ValueError(msg)
+
+    return _ret_work([output_tensor])
 
   def allgather(self, output_tensors_list, input_tensors, opts=None):
     for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
       is_scalar = (input_tensor.dim() == 0)
       if is_scalar:
         input_tensor = torch.reshape(input_tensor, (1,))
-      result = xm.all_gather(input_tensor, groups=self._mesh, pin_layout=False)
+      result = xm.all_gather(
+          input_tensor, dim=0, groups=self._mesh, pin_layout=False)
       for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
         with torch.no_grad():
           output_tensors[i].copy_(
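
As a rough eager-mode usage sketch of the updated _allgather_base (not part of the commit; it mirrors the per-process body the test runs through pjrt.run_multiprocess and assumes the "xla" process group can be initialized as shown in the test), both output layouts should now be accepted for a per-rank input of shape (1,):

import torch
import torch.distributed as dist
import torch_xla
import torch_xla.runtime as xr

# Per-process body; in practice this would be launched once per device.
dist.init_process_group("xla", init_method='xla://')
device = torch_xla.device()
value = torch.tensor([float(xr.global_ordinal())], device=device)  # shape (1,)

# Concatenated layout: output shape equals the gathered result, copied as-is.
concat_out = torch.empty((xr.world_size(),), device=device)
dist.all_gather_into_tensor(concat_out, value)

# Stacked layout: the gathered result is split per rank and re-stacked to (world_size, 1).
stack_out = torch.empty((xr.world_size(), 1), device=device)
dist.all_gather_into_tensor(stack_out, value)
torch_xla.sync()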
