
Commit fbcb07c (1 parent: 5d37add)

PERF: PtychographyUniformBatchSampler builds index only on rank 0

2 files changed: 61 additions, 9 deletions

src/ptychi/io_handles.py

Lines changed: 45 additions & 1 deletion
@@ -4,6 +4,7 @@
 from typing import Optional, Union
 import logging
 import math
+import os
 
 import torch
 import torch.utils
@@ -17,6 +18,7 @@
 from ptychi.device import AcceleratorModuleWrapper
 from ptychi.utils import to_tensor, to_numpy
 import ptychi.maths as pmath
+from ptychi.parallel import get_rank, get_world_size
 
 logger = logging.getLogger(__name__)
 
@@ -298,14 +300,56 @@ def __init__(self, positions, batch_size, *args, **kwargs):
         self.positions = positions
         self.batch_size = batch_size
 
-        self.build_indices()
+        self.build_or_sync_indices()
 
     def __len__(self):
         return math.ceil(len(self.positions) / self.batch_size)
 
     def __iter__(self):
         for i in np.random.choice(range(len(self)), len(self), replace=False):
             yield self.batches_of_indices[i]
+
+    def check_omp_num_threads(self):
+        if get_world_size() == 1:
+            return
+        val = os.environ.get("OMP_NUM_THREADS", "unset")
+        if not (val != "unset" and int(val) > 1):
+            logging.warning(
+                f"You are using multi-processing but OMP_NUM_THREADS is {val}. "
+                f"Index building in uniform batching mode may be slower than expected. "
+                f"Set OMP_NUM_THREADS to a value greater than 1 to improve performance."
+            )
+
+    def build_or_sync_indices(self):
+        self.check_omp_num_threads()
+        if get_rank() == 0:
+            self.build_indices()
+
+        if get_world_size() > 1:
+            # Temporarily move indices to GPU.
+            if get_rank() == 0:
+                batch_lengths = torch.tensor(
+                    [len(batch) for batch in self.batches_of_indices], device=torch.get_default_device(), dtype=torch.long
+                )
+                flat_indices = torch.cat(self.batches_of_indices).to(torch.get_default_device())
+            else:
+                batch_lengths = torch.empty(len(self), dtype=torch.long, device=torch.get_default_device())
+                flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=torch.get_default_device())
+
+            torch.distributed.broadcast(batch_lengths, src=0)
+            torch.distributed.broadcast(flat_indices, src=0)
+            batch_lengths = batch_lengths.to(self.positions.device)
+            flat_indices = flat_indices.to(self.positions.device)
+
+            # Re-assemble batch index list.
+            if get_rank() != 0:
+                batches = []
+                start = 0
+                for length in batch_lengths.tolist():
+                    end = start + length
+                    batches.append(flat_indices[start:end].clone())
+                    start = end
+                self.batches_of_indices = tuple(batches)
 
     def build_indices(self):
         dist_mat = torch.cdist(self.positions, self.positions, p=2)
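
The core of build_or_sync_indices is broadcasting a ragged list of index tensors from rank 0 using only fixed-size collectives: one broadcast for the per-batch lengths and one for the concatenated indices, with the receiving ranks sizing their buffers from quantities they can compute locally (len(self) and len(self.positions)). Below is a minimal sketch of that pattern, assuming an already initialized torch.distributed process group; the helper name sync_index_batches and the explicit n_batches/n_total arguments are illustrative, not part of the commit.

```python
import torch
import torch.distributed as dist


def sync_index_batches(batches, n_batches, n_total, device):
    """Broadcast a ragged list of 1-D index tensors from rank 0 to every rank.

    `batches` only needs to exist on rank 0; `n_batches` and `n_total` must be
    computable on every rank (in the sampler: len(self) and len(self.positions)).
    """
    if dist.get_rank() == 0:
        # Encode the ragged list as (per-batch lengths, flat concatenation).
        lengths = torch.tensor([len(b) for b in batches], dtype=torch.long, device=device)
        flat = torch.cat([b.to(device=device, dtype=torch.long) for b in batches])
    else:
        # Receiving ranks allocate buffers of the locally known sizes.
        lengths = torch.empty(n_batches, dtype=torch.long, device=device)
        flat = torch.empty(n_total, dtype=torch.long, device=device)

    # Two fixed-size broadcasts transfer the whole variable-length structure.
    dist.broadcast(lengths, src=0)
    dist.broadcast(flat, src=0)

    # Rebuild per-batch tensors from the flat buffer (on every rank).
    out, start = [], 0
    for length in lengths.tolist():
        out.append(flat[start:start + length].clone())
        start += length
    return tuple(out)
```

Staging the tensors on torch.get_default_device() before the broadcast and moving them back to self.positions.device afterwards matches the "Temporarily move indices to GPU" comment in the diff: MultiprocessMixin in parallel.py defaults to the nccl backend, which only handles CUDA tensors.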

src/ptychi/parallel.py

Lines changed: 16 additions & 8 deletions
@@ -13,22 +13,30 @@
 from ptychi.utils import to_tensor
 
 
+def get_rank():
+    try:
+        return dist.get_rank()
+    except ValueError:
+        return 0
+
+
+def get_world_size():
+    try:
+        return dist.get_world_size()
+    except ValueError:
+        return 1
+
+
 class MultiprocessMixin:
     backend = "nccl"
 
     @property
     def rank(self) -> int:
-        try:
-            return dist.get_rank()
-        except ValueError:
-            return 0
+        return get_rank()
 
     @property
     def n_ranks(self) -> int:
-        try:
-            return dist.get_world_size()
-        except ValueError:
-            return 1
+        return get_world_size()
 
     def get_chunk_of_current_rank(
         self,
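
Hoisting the try/except into module-level get_rank() / get_world_size() lets free-standing code such as PtychographyUniformBatchSampler query the topology without holding a MultiprocessMixin instance, and both helpers fall back to rank 0 / world size 1 when no default process group has been initialized. A small, hypothetical usage sketch (not part of the commit):

```python
# Hypothetical usage sketch: the helpers can be called unconditionally,
# whether the script runs standalone or under a multi-process launcher
# with an initialized torch.distributed process group.
from ptychi.parallel import get_rank, get_world_size

if __name__ == "__main__":
    if get_world_size() > 1:
        # Multi-process run: do one-off expensive work on rank 0 only.
        if get_rank() == 0:
            print(f"rank 0 of {get_world_size()}: building shared state")
    else:
        # Plain single-process run: the helpers return 0 / 1 without
        # requiring dist.init_process_group() to have been called.
        print("single-process run")
```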
