Skip to content

Commit eedb32b

Browse files
committed
FIX: in chunked_processing, replicate replicable tensors on the fly for each chunk instead of replicating them to full length before calling the function, to reduce memory usage
1 parent 4955f43 commit eedb32b

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

src/ptychi/data_structures/object.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,6 @@ def calculate_illumination_map(
561561
probe_int = probe.get_all_mode_intensity(opr_mode=0)[None, :, :]
562562
else:
563563
probe_int = probe.get_mode_and_opr_mode(mode=0, opr_mode=0)[None, ...].abs() ** 2
564-
# Shape of probe_int: (n_scan_points, h, w)
565-
probe_int = probe_int.repeat(len(positions_all), 1, 1)
566564

567565
# Stitch probes of all positions on the object buffer
568566
# TODO: allow setting chunk size externally
@@ -571,11 +569,13 @@ def calculate_illumination_map(
571569
common_kwargs={"op": "add"},
572570
chunkable_kwargs={
573571
"positions": positions_all.round().int() + self.pos_origin_coords,
574-
"patches": probe_int,
575572
},
576573
iterated_kwargs={
577574
"image": torch.zeros_like(object_.real).type(torch.get_default_dtype())
578575
},
576+
replicated_kwargs={
577+
"patches": probe_int,
578+
},
579579
chunk_size=64,
580580
)
581581
return probe_sq_map

src/ptychi/utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def chunked_processing(
365365
common_kwargs: dict,
366366
chunkable_kwargs: dict,
367367
iterated_kwargs: dict,
368+
replicated_kwargs: dict = None,
368369
chunk_size: int = 96,
369370
):
370371
"""
@@ -380,6 +381,10 @@ def chunked_processing(
380381
A dictionary of arguments that should be returned by `func`, then passed to `func`
381382
for the next chunk. The order of arguments should be the same as the returns of
382383
`func`.
384+
replicated_kwargs : dict, optional
385+
A dictionary of arguments that should be replicated for each chunk along the
386+
first dimension to match the chunk size. Tensors given here should have a first
387+
dimension of size 1 intended as the batch dimension.
383388
chunk_size : int, optional
384389
The size of each chunk. Default is 96.
385390
@@ -404,6 +409,13 @@ def chunked_processing(
404409
ind_st = ind_end
405410

406411
for kwargs_chunk in chunks_of_chunkable_args:
412+
current_chunk_size = kwargs_chunk[list(kwargs_chunk.keys())[0]].shape[0]
413+
if replicated_kwargs is not None:
414+
replicated_kwargs_chunk = {
415+
key: torch.repeat_interleave(value, current_chunk_size, dim=0)
416+
for key, value in replicated_kwargs.items()
417+
}
418+
kwargs_chunk.update(replicated_kwargs_chunk)
407419
ret = func(**common_kwargs, **kwargs_chunk, **iterated_kwargs)
408420
if isinstance(ret, tuple):
409421
for i, key in enumerate(iterated_kwargs.keys()):

0 commit comments

Comments (0)