Skip to content

Commit

Permalink
Added all_gather_and_cat
Browse files Browse the repository at this point in the history
  • Loading branch information
BloodAxe committed Sep 16, 2024
1 parent 35ef9ce commit 991ca0b
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions pytorch_toolbelt/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,17 +197,13 @@ def all_gather(data: Any) -> List[Any]:


def all_gather_and_cat(data: Any, dim=0) -> Any:
if not torch.is_tensor(data):
raise RuntimeError(f"Input data must be torch.Tensor, got {type(data)}")

device = data.device
data = all_gather(data)
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], np.ndarray):
return np.concatenate(data, axis=dim)
elif isinstance(data[0], list):
return [item for sublist in data for item in sublist]
else:
raise RuntimeError(
f"Unsupported data type {type(data[0])}. Input data must be list of torch.Tensor, np.ndarray or list"
)
data = [x.to(device) for x in data]
return torch.cat(data, dim=0)


def reduce_dict_sum(input_dict: Dict[Any, Any]) -> Dict[Any, Any]:
Expand Down

0 comments on commit 991ca0b

Please sign in to comment.