Skip to content

Commit 07f0483

Browse files
authored
Add support for custom collate with the StreamingDataLoader (#163)
1 parent 4b210f5 commit 07f0483

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

src/litdata/streaming/dataloader.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -476,6 +476,23 @@ def _try_put_index(self) -> None:
476476
super()._try_put_index()
477477

478478

479+
class StreamingDataLoaderCollateFn:
480+
def __init__(self, collate_fn: Optional[Callable] = None) -> None:
481+
self.collate_fn = collate_fn or default_collate
482+
483+
def __call__(self, items: List[Any]) -> Any:
484+
if len(items) > 0 and isinstance(items[0], dict) and __NUM_SAMPLES_YIELDED_KEY__ in items[0]:
485+
batch = self.collate_fn([item[__SAMPLES_KEY__] for item in items])
486+
return {
487+
__SAMPLES_KEY__: batch,
488+
__NUM_SAMPLES_YIELDED_KEY__: [
489+
torch.cumsum([torch.tensor(item[__NUM_SAMPLES_YIELDED_KEY__]) for item in items][-1], dim=0)
490+
],
491+
}
492+
493+
return self.collate_fn(items)
494+
495+
479496
class StreamingDataLoader(DataLoader):
480497
r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
481498
@@ -541,6 +558,7 @@ def __init__(
541558
prefetch_factor: Optional[int] = None,
542559
shuffle: Optional[bool] = None,
543560
drop_last: Optional[bool] = False,
561+
collate_fn: Optional[Callable] = None,
544562
**kwargs: Any,
545563
) -> None: # pyright: ignore
546564
if not isinstance(dataset, (StreamingDataset, CombinedStreamingDataset)):
@@ -563,6 +581,9 @@ def __init__(
563581
if profile_batches and num_workers == 0:
564582
raise ValueError("Profiling is supported only with num_workers >= 1.")
565583

584+
if collate_fn:
585+
collate_fn = StreamingDataLoaderCollateFn(collate_fn)
586+
566587
self.current_epoch = 0
567588
self.batch_size = batch_size
568589
self.num_workers = num_workers
@@ -581,6 +602,7 @@ def __init__(
581602
batch_size=batch_size,
582603
num_workers=num_workers,
583604
prefetch_factor=(10 if num_workers > 0 else None) if prefetch_factor is None else prefetch_factor,
605+
collate_fn=collate_fn,
584606
**kwargs,
585607
) # type: ignore
586608

tests/streaming/test_dataloader.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,31 @@ def test_dataloader_shuffle():
119119
StreamingDataLoader(dataset, batch_size=2, num_workers=1, shuffle=True)
120120
assert dataset._datasets[0].shuffle
121121
assert dataset._datasets[1].shuffle
122+
123+
124+
class TestStatefulDatasetDict(TestStatefulDataset):
125+
def __next__(self):
126+
return {"value": super().__next__()}
127+
128+
129+
def custom_collate_fn(samples):
130+
assert len(samples) == 2
131+
assert "value" in samples[0]
132+
return "received"
133+
134+
135+
def test_custom_collate():
136+
dataset = TestCombinedStreamingDataset(
137+
[TestStatefulDatasetDict(10, 1), TestStatefulDatasetDict(10, -1)],
138+
42,
139+
weights=(0.5, 0.5),
140+
iterate_over_all=False,
141+
)
142+
assert dataset._datasets[0].shuffle is None
143+
assert dataset._datasets[1].shuffle is None
144+
dataloader = StreamingDataLoader(dataset, batch_size=2, num_workers=0, shuffle=True, collate_fn=custom_collate_fn)
145+
assert dataset._datasets[0].shuffle
146+
assert dataset._datasets[1].shuffle
147+
dataloader_iter = iter(dataloader)
148+
assert next(dataloader_iter) == "received"
149+
assert dataloader._num_samples_yielded_combined[0] == [2]

0 commit comments

Comments
 (0)