Commit 6fb0ed5

PERF: PtychographyUniformBatchSampler builds index on GPU when available
1 parent: fbcb07c

1 file changed (+10, -18 lines)

src/ptychi/io_handles.py

Lines changed: 10 additions & 18 deletions
```diff
@@ -309,37 +309,26 @@ def __iter__(self):
         for i in np.random.choice(range(len(self)), len(self), replace=False):
             yield self.batches_of_indices[i]
 
-    def check_omp_num_threads(self):
-        if get_world_size() == 1:
-            return
-        val = os.environ.get("OMP_NUM_THREADS", "unset")
-        if not (val != "unset" and int(val) > 1):
-            logging.warning(
-                f"You are using multi-processing but OMP_NUM_THREADS is {val}. "
-                f"Index building in uniform batching mode may be slower than expected. "
-                f"Set OMP_NUM_THREADS to a value greater than 1 to improve performance."
-            )
-
     def build_or_sync_indices(self):
-        self.check_omp_num_threads()
+        orig_device = self.positions.device
+        self.positions = self.positions.to(torch.get_default_device())
+
         if get_rank() == 0:
             self.build_indices()
 
         if get_world_size() > 1:
             # Temporarily move indices to GPU.
             if get_rank() == 0:
                 batch_lengths = torch.tensor(
-                    [len(batch) for batch in self.batches_of_indices], device=torch.get_default_device(), dtype=torch.long
+                    [len(batch) for batch in self.batches_of_indices], device=self.positions.device, dtype=torch.long
                 )
-                flat_indices = torch.cat(self.batches_of_indices).to(torch.get_default_device())
+                flat_indices = torch.cat(self.batches_of_indices).to(self.positions.device)
             else:
-                batch_lengths = torch.empty(len(self), dtype=torch.long, device=torch.get_default_device())
-                flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=torch.get_default_device())
+                batch_lengths = torch.empty(len(self), dtype=torch.long, device=self.positions.device)
+                flat_indices = torch.empty(len(self.positions), dtype=torch.long, device=self.positions.device)
 
             torch.distributed.broadcast(batch_lengths, src=0)
             torch.distributed.broadcast(flat_indices, src=0)
-            batch_lengths = batch_lengths.to(self.positions.device)
-            flat_indices = flat_indices.to(self.positions.device)
 
             # Re-assemble batch index list.
             if get_rank() != 0:
@@ -350,6 +339,9 @@ def build_or_sync_indices(self):
                 batches.append(flat_indices[start:end].clone())
                 start = end
             self.batches_of_indices = tuple(batches)
+
+        # Move back to original device.
+        self.batches_of_indices = [x.to(orig_device) for x in self.batches_of_indices]
 
     def build_indices(self):
         dist_mat = torch.cdist(self.positions, self.positions, p=2)
```
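
Index building now keys off `torch.get_default_device()` instead of CPU thread parallelism, which is why the `OMP_NUM_THREADS` warning could be dropped. A minimal sketch of how a caller would route index building onto a GPU, assuming a CUDA-enabled PyTorch build (this snippet is illustrative and not part of the commit):

```python
import torch

# The sampler builds on torch.get_default_device() and moves results back
# to the original device afterwards, so selecting a CUDA default device is
# enough to run the torch.cdist distance computation on the GPU.
# (set_default_device/get_default_device exist in recent PyTorch releases.)
if torch.cuda.is_available():
    torch.set_default_device("cuda")
```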
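
For the multi-process path, the diff keeps the existing broadcast scheme but allocates the buffers directly on the build device. A self-contained sketch of that scheme, with hypothetical names (`sync_batches`, `num_batches`, `num_indices`) and assuming `torch.distributed` is already initialized:

```python
import torch
import torch.distributed as dist

def sync_batches(batches, num_batches, num_indices, device):
    """Broadcast rank 0's list of variable-length index tensors to all ranks.

    num_batches and num_indices must already be known on every rank; in the
    sampler they correspond to len(self) and len(self.positions).
    """
    if dist.get_rank() == 0:
        # Encode the ragged list as one flat tensor plus a tensor of lengths.
        lengths = torch.tensor([len(b) for b in batches], dtype=torch.long, device=device)
        flat = torch.cat(batches).to(device)
    else:
        # Pre-allocate receive buffers; broadcast fills them in place.
        lengths = torch.empty(num_batches, dtype=torch.long, device=device)
        flat = torch.empty(num_indices, dtype=torch.long, device=device)

    dist.broadcast(lengths, src=0)
    dist.broadcast(flat, src=0)

    # Re-split the flat tensor using the broadcast lengths.
    out, start = [], 0
    for n in lengths.tolist():
        out.append(flat[start:start + n].clone())
        start += n
    return tuple(out)
```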
