4 | 4 | from typing import Optional, Union
5 | 5 | import logging
6 | 6 | import math
| 7 | +import os
7 | 8 |
8 | 9 | import torch
9 | 10 | import torch.utils

17 | 18 | from ptychi.device import AcceleratorModuleWrapper
18 | 19 | from ptychi.utils import to_tensor, to_numpy
19 | 20 | import ptychi.maths as pmath
| 21 | +from ptychi.parallel import get_rank, get_world_size
20 | 22 |
21 | 23 | logger = logging.getLogger(__name__)
22 | 24 |
@@ -298,14 +300,56 @@ def __init__(self, positions, batch_size, *args, **kwargs): |
298 | 300 |         self.positions = positions
299 | 301 |         self.batch_size = batch_size
300 | 302 |
301 | | -        self.build_indices()
| 303 | +        self.build_or_sync_indices()
302 | 304 |
303 | 305 |     def __len__(self):
304 | 306 |         return math.ceil(len(self.positions) / self.batch_size)
305 | 307 |
306 | 308 |     def __iter__(self):
307 | 309 |         for i in np.random.choice(range(len(self)), len(self), replace=False):
308 | 310 |             yield self.batches_of_indices[i]
| 311 | +
| 312 | +    def check_omp_num_threads(self):
| 313 | +        if get_world_size() == 1:
| 314 | +            return
| 315 | +        val = os.environ.get("OMP_NUM_THREADS", "unset")
| 316 | +        if val == "unset" or int(val) <= 1:
| 317 | +            logger.warning(
| 318 | +                f"You are using multi-processing but OMP_NUM_THREADS is {val}. "
| 319 | +                "Index building in uniform batching mode may be slower than expected. "
| 320 | +                "Set OMP_NUM_THREADS to a value greater than 1 to improve performance."
| 321 | +            )
| 322 | +
| 323 | +    def build_or_sync_indices(self):
| 324 | +        self.check_omp_num_threads()
| 325 | +        if get_rank() == 0:
| 326 | +            self.build_indices()
| 327 | +
| 328 | +        if get_world_size() > 1:
| 329 | +            # Temporarily move indices to the default (GPU) device for the broadcast.
| 330 | +            if get_rank() == 0:
| 331 | +                batch_lengths = torch.tensor(
| 332 | +                    [len(batch) for batch in self.batches_of_indices], device=torch.get_default_device(), dtype=torch.long
| 333 | +                )
| 334 | +                flat_indices = torch.cat(self.batches_of_indices).to(torch.get_default_device())
| 335 | +            else:
| 336 | +                batch_lengths = torch.empty(len(self), dtype=torch.long, device=torch.get_default_device())
| 337 | +                flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=torch.get_default_device())
| 338 | +
| 339 | +            torch.distributed.broadcast(batch_lengths, src=0)
| 340 | +            torch.distributed.broadcast(flat_indices, src=0)
| 341 | +            batch_lengths = batch_lengths.to(self.positions.device)
| 342 | +            flat_indices = flat_indices.to(self.positions.device)
| 343 | +
| 344 | +            # Re-assemble the batch index list on non-source ranks.
| 345 | +            if get_rank() != 0:
| 346 | +                batches = []
| 347 | +                start = 0
| 348 | +                for length in batch_lengths.tolist():
| 349 | +                    end = start + length
| 350 | +                    batches.append(flat_indices[start:end].clone())
| 351 | +                    start = end
| 352 | +                self.batches_of_indices = tuple(batches)
309 | 353 |
310 | 354 |     def build_indices(self):
311 | 355 |         dist_mat = torch.cdist(self.positions, self.positions, p=2)
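For reference, the variable-length broadcast pattern that build_or_sync_indices relies on can be sketched in isolation. The helper below is illustrative only: the name broadcast_index_batches is not part of ptychi, and it assumes an already-initialized torch.distributed process group where, on the source rank, batches is a list of 1-D torch.long index tensors. Lengths are broadcast first so receiving ranks can allocate the flat buffer, then the flattened indices are broadcast and split back into batches.

import torch
import torch.distributed as dist

def broadcast_index_batches(batches, device, src=0):
    # Illustrative sketch: broadcast a list of variable-length index tensors
    # from rank `src` to all ranks. Non-source ranks may pass an empty list.
    rank = dist.get_rank()

    # 1) Broadcast the number of batches so receivers know how many lengths to expect.
    if rank == src:
        n_batches = torch.tensor([len(batches)], dtype=torch.long, device=device)
    else:
        n_batches = torch.empty(1, dtype=torch.long, device=device)
    dist.broadcast(n_batches, src=src)

    # 2) Broadcast the per-batch lengths.
    if rank == src:
        lengths = torch.tensor([len(b) for b in batches], dtype=torch.long, device=device)
    else:
        lengths = torch.empty(int(n_batches.item()), dtype=torch.long, device=device)
    dist.broadcast(lengths, src=src)

    # 3) Broadcast the flattened indices and split them back into per-batch tensors.
    if rank == src:
        flat = torch.cat([b.to(device) for b in batches])
    else:
        flat = torch.empty(int(lengths.sum().item()), dtype=torch.long, device=device)
    dist.broadcast(flat, src=src)
    return list(torch.split(flat, lengths.tolist()))

In build_or_sync_indices itself the first broadcast is unnecessary, since every rank can already compute the batch count locally from len(self).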