Commit 522b853

PERF: PtychographyUniformBatchSampler builds index only on rank 0
1 parent 5d37add commit 522b853

2 files changed: +48 -9 lines changed


src/ptychi/io_handles.py

Lines changed: 32 additions & 1 deletion
@@ -17,6 +17,7 @@
 from ptychi.device import AcceleratorModuleWrapper
 from ptychi.utils import to_tensor, to_numpy
 import ptychi.maths as pmath
+from ptychi.parallel import get_rank, get_world_size
 
 logger = logging.getLogger(__name__)
 
@@ -298,14 +299,44 @@ def __init__(self, positions, batch_size, *args, **kwargs):
         self.positions = positions
         self.batch_size = batch_size
 
-        self.build_indices()
+        self.build_or_sync_indices()
 
     def __len__(self):
         return math.ceil(len(self.positions) / self.batch_size)
 
     def __iter__(self):
         for i in np.random.choice(range(len(self)), len(self), replace=False):
             yield self.batches_of_indices[i]
+
+    def build_or_sync_indices(self):
+        if get_rank() == 0:
+            self.build_indices()
+
+        if get_world_size() > 1:
+            # Temporarily move indices to GPU.
+            if get_rank() == 0:
+                batch_lengths = torch.tensor(
+                    [len(batch) for batch in self.batches_of_indices], device=torch.get_default_device(), dtype=torch.long
+                )
+                flat_indices = torch.cat(self.batches_of_indices)
+            else:
+                batch_lengths = torch.empty(len(self), dtype=torch.long, device=torch.get_default_device())
+                flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=torch.get_default_device())
+
+            torch.distributed.broadcast(batch_lengths, src=0)
+            torch.distributed.broadcast(flat_indices, src=0)
+            batch_lengths = batch_lengths.to(self.positions.device)
+            flat_indices = flat_indices.to(self.positions.device)
+
+            # Re-assemble batch index list.
+            if get_rank() != 0:
+                batches = []
+                start = 0
+                for length in batch_lengths.tolist():
+                    end = start + length
+                    batches.append(flat_indices[start:end].clone())
+                    start = end
+                self.batches_of_indices = tuple(batches)
 
     def build_indices(self):
         dist_mat = torch.cdist(self.positions, self.positions, p=2)
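
Note on the pattern used above: torch.distributed.broadcast operates on tensors whose size and dtype are known on every rank, so a ragged list of batch index tensors cannot be broadcast directly. build_or_sync_indices therefore flattens the list on rank 0, broadcasts the per-batch lengths together with the flat index buffer, and re-splits the buffer on the non-zero ranks. A minimal, self-contained sketch of that flatten/broadcast/re-split idea (function and argument names here are illustrative, not part of ptychi):

# Sketch only: broadcast a ragged list of 1-D index tensors from rank 0.
# Assumes torch.distributed.init_process_group() has already been called and
# that `device` is valid for the active backend (e.g. a CUDA device for NCCL).
import torch
import torch.distributed as dist

def broadcast_ragged(batches, n_batches, n_total, device):
    # Every rank must already know the number of batches and the total number
    # of indices so it can allocate receive buffers of the right size.
    if dist.get_rank() == 0:
        lengths = torch.tensor([len(b) for b in batches], dtype=torch.long, device=device)
        flat = torch.cat([b.to(device) for b in batches])
    else:
        lengths = torch.empty(n_batches, dtype=torch.long, device=device)
        flat = torch.empty(n_total, dtype=torch.long, device=device)

    dist.broadcast(lengths, src=0)
    dist.broadcast(flat, src=0)

    # Re-split the flat buffer into per-batch tensors using the lengths.
    out, start = [], 0
    for n in lengths.tolist():
        out.append(flat[start:start + n].clone())
        start += n
    return tuple(out)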

src/ptychi/parallel.py

Lines changed: 16 additions & 8 deletions
@@ -13,22 +13,30 @@
 from ptychi.utils import to_tensor
 
 
+def get_rank():
+    try:
+        return dist.get_rank()
+    except ValueError:
+        return 0
+
+
+def get_world_size():
+    try:
+        return dist.get_world_size()
+    except ValueError:
+        return 1
+
+
 class MultiprocessMixin:
     backend = "nccl"
 
     @property
     def rank(self) -> int:
-        try:
-            return dist.get_rank()
-        except ValueError:
-            return 0
+        return get_rank()
 
     @property
     def n_ranks(self) -> int:
-        try:
-            return dist.get_world_size()
-        except ValueError:
-            return 1
+        return get_world_size()
 
     def get_chunk_of_current_rank(
         self,
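
The new module-level helpers mirror the former properties but can be called outside MultiprocessMixin. When no default process group has been initialized, dist.get_rank() and dist.get_world_size() raise (the except clauses here assume ValueError), so the helpers fall back to rank 0 and world size 1 in a single-process run. A hypothetical caller, just to illustrate the intended usage:

# Usage sketch (hypothetical helper, not part of the repository): do work on
# rank 0 only, degrading gracefully when torch.distributed is not in use.
import torch.distributed as dist
from ptychi.parallel import get_rank, get_world_size

def print_on_rank0(message: str) -> None:
    # In a single-process run get_rank() returns 0, so this prints normally
    # without any process group having been initialized.
    if get_rank() == 0:
        print(message)
    # Optional: keep ranks loosely in step when running multi-process.
    if get_world_size() > 1:
        dist.barrier()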
