@@ -476,6 +476,23 @@ def _try_put_index(self) -> None:
476
476
super ()._try_put_index ()
477
477
478
478
479
class StreamingDataLoaderCollateFn:
    """Collate wrapper that keeps sample-count bookkeeping away from the wrapped collate function.

    When the incoming items are dicts carrying ``__NUM_SAMPLES_YIELDED_KEY__``, only their
    ``__SAMPLES_KEY__`` payloads are handed to the wrapped collate function and the
    bookkeeping value is re-attached to the collated result. Any other input is forwarded
    to the wrapped collate function untouched.
    """

    def __init__(self, collate_fn: Optional[Callable] = None) -> None:
        """Store the collate function to wrap.

        Args:
            collate_fn: Callable used to collate the samples. Falls back to
                ``default_collate`` when ``None`` (or any falsy value) is given.
        """
        self.collate_fn = collate_fn or default_collate

    def __call__(self, items: List[Any]) -> Any:
        """Collate ``items``, preserving the yielded-samples bookkeeping when present.

        Args:
            items: The elements to collate.

        Returns:
            If the first item is a dict containing ``__NUM_SAMPLES_YIELDED_KEY__``: a dict
            with the collated samples under ``__SAMPLES_KEY__`` and, under
            ``__NUM_SAMPLES_YIELDED_KEY__``, a single-element list holding the cumulative
            sum (along dim 0) of the last item's value for that key. Otherwise, whatever
            the wrapped collate function returns for ``items``.
        """
        # `len(...) > 0` is kept on purpose instead of a truthiness test: `items` is not
        # guaranteed to be a plain list, and truthiness is ill-defined for e.g. tensors.
        has_bookkeeping = (
            len(items) > 0 and isinstance(items[0], dict) and __NUM_SAMPLES_YIELDED_KEY__ in items[0]
        )
        if not has_bookkeeping:
            return self.collate_fn(items)

        batch = self.collate_fn([item[__SAMPLES_KEY__] for item in items])

        # Only the last item's value is reported, so build a tensor for that item alone
        # rather than materialising one tensor per item and discarding all but the last.
        last_num_samples_yielded = torch.tensor(items[-1][__NUM_SAMPLES_YIELDED_KEY__])
        return {
            __SAMPLES_KEY__: batch,
            __NUM_SAMPLES_YIELDED_KEY__: [torch.cumsum(last_num_samples_yielded, dim=0)],
        }
494
+
495
+
479
496
class StreamingDataLoader (DataLoader ):
480
497
r"""The StreamingDataLoader combines a dataset and a sampler, and provides an iterable over the given dataset.
481
498
@@ -541,6 +558,7 @@ def __init__(
541
558
prefetch_factor : Optional [int ] = None ,
542
559
shuffle : Optional [bool ] = None ,
543
560
drop_last : Optional [bool ] = False ,
561
+ collate_fn : Optional [Callable ] = None ,
544
562
** kwargs : Any ,
545
563
) -> None : # pyright: ignore
546
564
if not isinstance (dataset , (StreamingDataset , CombinedStreamingDataset )):
@@ -563,6 +581,9 @@ def __init__(
563
581
if profile_batches and num_workers == 0 :
564
582
raise ValueError ("Profiling is supported only with num_workers >= 1." )
565
583
584
+ if collate_fn :
585
+ collate_fn = StreamingDataLoaderCollateFn (collate_fn )
586
+
566
587
self .current_epoch = 0
567
588
self .batch_size = batch_size
568
589
self .num_workers = num_workers
@@ -581,6 +602,7 @@ def __init__(
581
602
batch_size = batch_size ,
582
603
num_workers = num_workers ,
583
604
prefetch_factor = (10 if num_workers > 0 else None ) if prefetch_factor is None else prefetch_factor ,
605
+ collate_fn = collate_fn ,
584
606
** kwargs ,
585
607
) # type: ignore
586
608
0 commit comments