17 | 17 | from ptychi.device import AcceleratorModuleWrapper |
18 | 18 | from ptychi.utils import to_tensor, to_numpy |
19 | 19 | import ptychi.maths as pmath |
| 20 | +from ptychi.parallel import get_rank, get_world_size |
20 | 21 |
21 | 22 | logger = logging.getLogger(__name__) |
22 | 23 |
@@ -298,14 +299,44 @@ def __init__(self, positions, batch_size, *args, **kwargs): |
298 | 299 | self.positions = positions |
299 | 300 | self.batch_size = batch_size |
300 | 301 |
301 | | - self.build_indices() |
| 302 | + self.build_or_sync_indices() |
302 | 303 |
303 | 304 | def __len__(self): |
304 | 305 | return math.ceil(len(self.positions) / self.batch_size) |
305 | 306 |
306 | 307 | def __iter__(self): |
307 | 308 | for i in np.random.choice(range(len(self)), len(self), replace=False): |
308 | 309 | yield self.batches_of_indices[i] |
| 310 | + |
| 311 | + def build_or_sync_indices(self): |
| 312 | + if get_rank() == 0: |
| 313 | + self.build_indices() |
| 314 | + |
| 315 | + if get_world_size() > 1: |
| 316 | + # Temporarily move indices to GPU. |
| 317 | + if get_rank() == 0: |
| 318 | + batch_lengths = torch.tensor( |
| 319 | + [len(batch) for batch in self.batches_of_indices], device=torch.get_default_device(), dtype=torch.long |
| 320 | + ) |
|     | 321 | +                flat_indices = torch.cat(self.batches_of_indices).to(torch.get_default_device())
| 322 | + else: |
| 323 | + batch_lengths = torch.empty(len(self), dtype=torch.long, device=torch.get_default_device()) |
| 324 | + flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=torch.get_default_device()) |
| 325 | + |
| 326 | + torch.distributed.broadcast(batch_lengths, src=0) |
| 327 | + torch.distributed.broadcast(flat_indices, src=0) |
| 328 | + batch_lengths = batch_lengths.to(self.positions.device) |
| 329 | + flat_indices = flat_indices.to(self.positions.device) |
| 330 | + |
| 331 | + # Re-assemble batch index list. |
| 332 | + if get_rank() != 0: |
| 333 | + batches = [] |
| 334 | + start = 0 |
| 335 | + for length in batch_lengths.tolist(): |
| 336 | + end = start + length |
| 337 | + batches.append(flat_indices[start:end].clone()) |
| 338 | + start = end |
| 339 | + self.batches_of_indices = tuple(batches) |
309 | 340 |
310 | 341 | def build_indices(self): |
311 | 342 | dist_mat = torch.cdist(self.positions, self.positions, p=2) |
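For reference, the broadcast above uses a length-prefixed layout because torch.distributed.broadcast requires the tensor to have the same shape and dtype on every rank, which variable-length batches do not satisfy directly. The sketch below shows the same pattern in isolation; it is illustrative only (sync_index_batches is not part of ptychi), assumes a process group has already been initialized, and uses CPU tensors for simplicity (e.g. with the gloo backend).

    import torch
    import torch.distributed as dist

    def sync_index_batches(batches, n_batches, n_total):
        # On rank 0, `batches` is a sequence of 1-D long index tensors of varying
        # length; other ranks may pass None. `n_batches` and `n_total` must be
        # known on every rank (in the commit they follow from len(positions) and
        # batch_size, so no extra communication is needed for them).
        if dist.get_rank() == 0:
            lengths = torch.tensor([len(b) for b in batches], dtype=torch.long)
            flat = torch.cat(list(batches))
        else:
            lengths = torch.empty(n_batches, dtype=torch.long)
            flat = torch.empty(n_total, dtype=torch.long)

        # Send the per-batch lengths first, then the flattened indices; both
        # tensors have shapes that every rank can allocate in advance.
        dist.broadcast(lengths, src=0)
        dist.broadcast(flat, src=0)

        # Re-split the flat tensor into per-batch index tensors.
        return list(torch.split(flat, lengths.tolist()))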