🚀 Feature
Support for subclasses of StreamingDataset that yield a different number of samples than the original dataset.
Motivation
There are cases where a user might want to subclass StreamingDataset to create a dataset that yields a different number of samples than the default implementation. For example, in #552 I read variable-length samples and queue them. Once the total queued length is long enough, I pop from the queue and yield until the total length is no longer long enough, at which point I load the next sample. Another example is skipping samples, e.g. low-quality ones.
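A minimal sketch of the packing pattern described above, written without litdata so it runs on its own. It is a simplified token-packing variant, not the actual code from #552: `pack_samples`, `chunk_size`, and the `stats` dict are hypothetical. It shows how the number of items read from the source and the number of items yielded diverge.

```python
from typing import Dict, Iterable, Iterator, List


def pack_samples(
    samples: Iterable[List[int]], chunk_size: int, stats: Dict[str, int]
) -> Iterator[List[int]]:
    """Hypothetical packing generator (not litdata API).

    Variable-length samples are queued in a buffer; fixed-size chunks are
    emitted while the buffer holds at least `chunk_size` tokens.
    """
    buffer: List[int] = []
    for sample in samples:
        stats["read"] += 1  # one item consumed from the source dataset
        buffer.extend(sample)
        while len(buffer) >= chunk_size:
            chunk, buffer = buffer[:chunk_size], buffer[chunk_size:]
            stats["yielded"] += 1  # one item handed to the consumer
            yield chunk
    # In this sketch, a trailing remainder shorter than `chunk_size` is dropped.


stats = {"read": 0, "yielded": 0}
chunks = list(pack_samples([[1, 2, 3, 4, 5], [6], [7, 8, 9, 10, 11, 12, 13]], 2, stats))
print(chunks)  # [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]
print(stats)   # {'read': 3, 'yielded': 6}
```

Three items are read but six are yielded, so any counter that only tracks yields overstates how far into the source dataset the iterator has progressed.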
While the example in #552 appears to work at first glance, it actually breaks key features: the internal sample counters are incremented for every sample yielded, regardless of how many items were actually read from the underlying dataset. For example, in StreamingDataLoader:
litData/src/litdata/streaming/dataloader.py, line 664 (commit 6ceef3e):

```python
self._num_samples_yielded_streaming += self.batch_size
```
Or in ParallelStreamingDataset:
litData/src/litdata/streaming/parallel.py, line 349 (commit 6ceef3e):

```python
self._num_samples_yielded[i] = 1 if _reset else self._num_samples_yielded[i] + 1
```
This breaks state saving and loading: on resume, the underlying dataset is restored at an offset that does not match the number of samples actually read from it.
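To make the resume problem concrete, here is a self-contained simulation, assuming an indexable dataset and a hypothetical `filtered` iterator that skips samples marked `"bad"` (the skip-by-quality case from the motivation). Resuming at the yielded count, which is what the counters quoted above track, replays a sample, while resuming at the number of items actually consumed does not. It only models the offset arithmetic and does not reproduce litdata's actual resume logic.

```python
from typing import Dict, Iterator, Sequence


def filtered(samples: Sequence[str], start: int, stats: Dict[str, int]) -> Iterator[str]:
    """Hypothetical filtering iterator; `start` plays the role of the resume offset."""
    for index in range(start, len(samples)):
        stats["consumed"] = index + 1  # items consumed from the start of the dataset
        if samples[index] == "bad":  # skip low-quality samples
            continue
        stats["yielded"] += 1
        yield samples[index]


data = ["a", "bad", "b", "bad", "c", "d"]

# First run: the consumer stops after two yielded samples.
stats = {"consumed": 0, "yielded": 0}
it = filtered(data, start=0, stats=stats)
first_run = [next(it), next(it)]

# Resuming at the yielded count replays "b" ...
wrong = list(filtered(data, start=stats["yielded"], stats={"consumed": 0, "yielded": 0}))
# ... while resuming at the number of items actually consumed does not.
right = list(filtered(data, start=stats["consumed"], stats={"consumed": 0, "yielded": 0}))

print(first_run, wrong, right)  # ['a', 'b'] ['b', 'c', 'd'] ['c', 'd']
```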
Another issue is poor support for overriding __len__. Users should be able to override __len__ to return the number of samples the new dataset yields, or at the very least to raise TypeError("... has no len()") if the new length cannot be computed without iterating over the dataset. Currently this breaks features, because StreamingDataset calls len(self) in several places and assumes it returns the actual number of items in the dataset. For example:
litData/src/litdata/streaming/dataset.py, lines 651 to 654 (commit 6ceef3e):

```python
if state["num_samples_yielded"] > len(self):
    raise ValueError(
        "The provided `num_samples_yielded` state is greater than the dataset length. "
        f"Found `{state['num_samples_yielded']}` instead of `{len(self)}`."
```
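A small stand-alone sketch of the `__len__` problem. `RawDataset` and `PackedDataset` are hypothetical stand-ins for StreamingDataset and a subclass. `check_state_current` mirrors the quoted check, and `check_state_raw` is one possible alternative (an assumption, not existing litdata code) that compares against the raw item count instead of calling `len(self)`.

```python
from typing import Any, Iterable


class RawDataset:
    """Hypothetical stand-in for StreamingDataset."""

    def __init__(self, items: Iterable[Any]) -> None:
        self.items = list(items)

    def __len__(self) -> int:
        return len(self.items)

    def check_state_current(self, num_samples_yielded: int) -> None:
        # Mirrors the quoted check: assumes len(self) is the raw item count.
        if num_samples_yielded > len(self):
            raise ValueError("num_samples_yielded is greater than the dataset length")

    def check_state_raw(self, num_samples_yielded: int) -> None:
        # Hypothetical alternative: independent of any __len__ override.
        if num_samples_yielded > len(self.items):
            raise ValueError("num_samples_yielded is greater than the dataset length")


class PackedDataset(RawDataset):
    """Hypothetical subclass whose output length is unknown without iterating."""

    def __len__(self) -> int:
        raise TypeError(f"object of type '{type(self).__name__}' has no len()")


ds = PackedDataset(range(10))
error = ""
try:
    ds.check_state_current(5)
except TypeError as err:
    error = str(err)
print(error)  # object of type 'PackedDataset' has no len()
ds.check_state_raw(5)  # passes: 5 <= 10 stored items
```

The current-style check fails with a TypeError from the subclass's __len__ before the intended validation even runs, whereas a check against the raw item count keeps working regardless of how __len__ is overridden.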
Pitch
I have not fully thought this through, but one option could be to have users override a method other than __next__, similar to transform. This method would contain the same logic as __next__, but users would not have to worry about updating the internal counters. We would then call this user-defined method inside __next__ and pass the correct sample-counter updates to StreamingDataLoader, similar to how we handle the updates passed to StreamingDataLoader for CombinedStreamingDataset and ParallelStreamingDataset.
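A rough sketch of this idea, assuming a hypothetical base class rather than litdata's real one: subclasses override `next_item` (a stand-in for the method other than `__next__`) and pull raw items through `read_raw`, while the base `__next__` maintains both counters. `CountingDataset`, `SkipBad`, `next_item`, `read_raw`, and this `state_dict` layout are all assumptions for illustration.

```python
from typing import Any, Dict, Sequence


class CountingDataset:
    """Hypothetical base class (not litdata API).

    Subclasses override `next_item` and fetch raw items via `read_raw`;
    the base class updates the counters, so subclasses never touch them.
    """

    def __init__(self, items: Sequence[Any], start: int = 0) -> None:
        self.items = items
        self.num_raw_read = start  # resume offset into the underlying data
        self.num_yielded = 0

    def read_raw(self) -> Any:
        if self.num_raw_read >= len(self.items):
            raise StopIteration  # ends iteration when propagated out of __next__
        item = self.items[self.num_raw_read]
        self.num_raw_read += 1
        return item

    def next_item(self) -> Any:
        # Default behaviour: one raw item per yielded item.
        return self.read_raw()

    def __iter__(self) -> "CountingDataset":
        return self

    def __next__(self) -> Any:
        item = self.next_item()  # user-defined logic
        self.num_yielded += 1  # counter maintained by the base class
        return item

    def state_dict(self) -> Dict[str, int]:
        return {"num_raw_read": self.num_raw_read, "num_yielded": self.num_yielded}


class SkipBad(CountingDataset):
    """Hypothetical subclass that skips low-quality samples."""

    def next_item(self) -> Any:
        while True:
            item = self.read_raw()
            if item != "bad":
                return item


ds = SkipBad(["a", "bad", "b", "bad", "c", "d"])
first = [next(ds), next(ds)]
state = ds.state_dict()  # {'num_raw_read': 3, 'num_yielded': 2}
resumed = list(SkipBad(ds.items, start=state["num_raw_read"]))
print(first, resumed)  # ['a', 'b'] ['c', 'd']
```

With the counters kept in the base class, the saved offset (`num_raw_read`) already reflects the skipped samples. A subclass that buffers items across calls, like the packing example in the motivation, would also need to save its buffer; this sketch does not handle that.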
This could turn into a fairly large undertaking, so there may be a smarter solution.