@@ -5,7 +5,7 @@
 from torch_xla._internal import rendezvous
 import logging
 import os
-from torch._C._distributed_c10d import ProcessGroup
+from torch._C._distributed_c10d import ProcessGroup, AllgatherOptions


 def _create_xla_process_group(prefix_store, rank, size, timeout):
@@ -81,16 +81,37 @@ def allreduce(self, tensors, all_reduce_options):
     xm.all_reduce(reduce_type, tensors, groups=self._mesh, pin_layout=False)
     return _ret_work(tensors)

-  # method for dist.all_gather_into_tensor under eager mode.
-  def _allgather_base(self, output_tensor, input_tensor, opts):
-    return self.allgather(output_tensor, input_tensor, opts)
+  # This method is called for dist.all_gather_into_tensor under eager mode.
+  # https://docs.pytorch.org/docs/stable/distributed.html#torch.distributed.all_gather_into_tensor
+  def _allgather_base(self, output_tensor: torch.Tensor,
+                      input_tensor: torch.Tensor, opts: AllgatherOptions):
+    is_scalar = (input_tensor.dim() == 0)
+    if is_scalar:
+      input_tensor = torch.reshape(input_tensor, (1,))
+
+    result = xm.all_gather(
+        input_tensor, dim=0, groups=self._mesh, pin_layout=False)
+
+    if result.shape == output_tensor.shape:
+      output_tensor.copy_(result, non_blocking=True)
+      return _ret_work([output_tensor])
+
+    stacked_result = torch.stack(
+        torch.split(result, input_tensor.shape[0], dim=0), dim=0)
+    if stacked_result.shape == output_tensor.shape:
+      output_tensor.copy_(stacked_result, non_blocking=True)
+      return _ret_work([output_tensor])
+
+    msg = f"Input shape {input_tensor.shape} and output shape {output_tensor.shape} are not compatible for all_gather_into_tensor. Input must be stacked or concatenated to create output."
+    raise ValueError(msg)

   def allgather(self, output_tensors_list, input_tensors, opts=None):
     for input_tensor, output_tensors in zip(input_tensors, output_tensors_list):
       is_scalar = (input_tensor.dim() == 0)
       if is_scalar:
         input_tensor = torch.reshape(input_tensor, (1,))
-      result = xm.all_gather(input_tensor, groups=self._mesh, pin_layout=False)
+      result = xm.all_gather(
+          input_tensor, dim=0, groups=self._mesh, pin_layout=False)
       for i, slice in enumerate(torch.split(result, input_tensor.shape[0])):
         with torch.no_grad():
           output_tensors[i].copy_(
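
For reference, a minimal usage sketch of the path this diff targets: calling dist.all_gather_into_tensor on the "xla" backend in eager mode, which dispatches to the _allgather_base method above. The launcher function _mp_fn, the torch_xla.launch and torch_xla.experimental.eager_mode calls, and the tensor shapes are illustrative assumptions, not part of this change; the two output tensors exercise the concatenated and stacked branches respectively.

# Illustrative sketch only; assumes a multi-process XLA environment.
import torch
import torch.distributed as dist
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" backend


def _mp_fn(index):
  # Assumption: eager mode is enabled so the collective runs through
  # the _allgather_base method above instead of being traced lazily.
  torch_xla.experimental.eager_mode(True)
  dist.init_process_group("xla", init_method="xla://")

  device = xm.xla_device()
  world_size = dist.get_world_size()
  rank = dist.get_rank()

  # Each rank contributes a (2, 3) tensor filled with its rank id.
  inp = torch.full((2, 3), float(rank), device=device)

  # Concatenated output convention: shape (world_size * 2, 3).
  out_cat = torch.empty(world_size * 2, 3, device=device)
  dist.all_gather_into_tensor(out_cat, inp)

  # Stacked output convention: shape (world_size, 2, 3).
  out_stack = torch.empty(world_size, 2, 3, device=device)
  dist.all_gather_into_tensor(out_stack, inp)


if __name__ == '__main__':
  torch_xla.launch(_mp_fn)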