Support for subclasses of StreamingDataset with different length #642

@philgzl

Description

🚀 Feature

Support subclasses of StreamingDataset that yield a different number of samples than the original dataset.

Motivation

There are cases where a user might want to subclass StreamingDataset to create a dataset that yields a different number of samples than the default implementation. For example, in #552 I read variable-length samples and push them onto a queue; once the queued samples are collectively long enough, I pop from the queue and yield until they are no longer long enough, at which point I read the next sample. Another example is skipping some samples, e.g. because they have bad quality.
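As a concrete illustration of the queueing pattern from #552, here is a minimal, hypothetical sketch (the class and attribute names are made up, not litdata API) that buffers variable-length samples and yields fixed-size chunks, so the number of yielded items differs from the number of items read:

```python
from collections import deque


class ChunkedDataset:
    """Hypothetical sketch (not litdata API): buffers variable-length
    samples and yields fixed-size chunks, so the number of yielded items
    differs from the number of items read from the underlying dataset."""

    def __init__(self, samples, chunk_len):
        self._samples = iter(samples)  # stands in for the parent __next__
        self._chunk_len = chunk_len
        self._buffer = deque()

    def __iter__(self):
        return self

    def __next__(self):
        # Read more samples until the buffer holds at least one full chunk.
        while len(self._buffer) < self._chunk_len:
            try:
                self._buffer.extend(next(self._samples))
            except StopIteration:
                if not self._buffer:
                    raise  # nothing left to yield
                break  # yield a final, shorter chunk
        n = min(self._chunk_len, len(self._buffer))
        return [self._buffer.popleft() for _ in range(n)]
```

Iterating over two variable-length samples `[1, 2, 3]` and `[4, 5]` with `chunk_len=2` yields three chunks from two dataset reads, which is exactly the yield/read mismatch described above.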

While the example in #552 works at first glance, it actually breaks key features, because the internal sample counters are incremented for each sample yielded, regardless of how many items were actually read from the dataset. For example, in StreamingDataLoader:

self._num_samples_yielded_streaming += self.batch_size

Or in ParallelStreamingDataset:

self._num_samples_yielded[i] = 1 if _reset else self._num_samples_yielded[i] + 1

This naturally breaks state saving and loading, as the internal dataset is resumed at an offset that does not match the number of samples actually read from the dataset.
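To make the drift concrete, here is a toy illustration (plain Python with assumed numbers, no litdata involved) of a subclass that skips bad samples: it yields fewer items than it reads, so a counter incremented per yield no longer equals the correct resume offset:

```python
# Toy illustration (assumed numbers, no litdata): a subclass skips "bad"
# samples, so items yielded != items read from the underlying dataset.
dataset = [0, 1, 2, 3, 4]
is_good = lambda x: x % 3 != 0  # skip every third sample

yielded, items_read = [], 0
for x in dataset:
    items_read += 1
    if is_good(x):
        yielded.append(x)

# A per-yield counter says 3, but resuming the underlying dataset at
# offset 3 would re-read items 3 and 4: the true offset is 5.
num_samples_yielded = len(yielded)
assert (num_samples_yielded, items_read) == (3, 5)
```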

Another issue is the poor support for overriding __len__. Users should be allowed to override __len__ to return the new number of samples the dataset yields, or at the very least to raise TypeError("... has no len()") if the new length cannot be computed without iterating over the dataset. But this currently breaks features, because len(self) is called in multiple places in StreamingDataset and is assumed to return the actual number of items in the dataset. For example:

if state["num_samples_yielded"] > len(self):
    raise ValueError(
        "The provided `num_samples_yielded` state is greater than the dataset length. "
        f"Found `{state['num_samples_yielded']}` instead of `{len(self)}`."
    )
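For instance, a subclass whose length is unknowable up front might reasonably want to behave like built-in objects that have no len(). A hypothetical sketch (the class name is made up, and the real subclass would derive from StreamingDataset):

```python
class FilteredDataset:  # stands in for a StreamingDataset subclass
    """Hypothetical subclass whose length cannot be computed without
    iterating over the dataset; raising TypeError mimics objects that
    have no len(), but would break internal len(self) calls like the
    state check above."""

    def __len__(self):
        raise TypeError(f"object of type {type(self).__name__!r} has no len()")
```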

Pitch

I have not fully thought this through, but one option could be to instruct users to override a method other than __next__, à la transform. This method would have the same logic as __next__, but users would not have to worry about updating internal counters. We would then call this user-defined method inside __next__ and pass the correct sample counter updates to StreamingDataLoader, similar to how we handle the updates passed to StreamingDataLoader for CombinedStreamingDataset and ParallelStreamingDataset.
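A minimal sketch of what this could look like (all names here are made up, not a proposed final API): the base class owns a single `_read` helper that updates the counter, and users override `generate` instead of `__next__`:

```python
class Base:
    """Toy stand-in for StreamingDataset (hypothetical, not litdata API)."""

    def __init__(self, items):
        self._it = iter(items)
        self._num_items_read = 0  # the counter the loader should rely on

    def _read(self):
        # Single point where items are consumed, so the counter stays correct
        # no matter how many samples the subclass actually yields.
        item = next(self._it)
        self._num_items_read += 1
        return item

    def generate(self):
        # Default behavior: one yielded sample per item read.
        # Subclasses override this instead of __next__.
        return self._read()

    def __iter__(self):
        return self

    def __next__(self):
        return self.generate()


class SkipBad(Base):
    def generate(self):
        # Yields fewer samples than it reads; the counter stays accurate.
        while True:
            item = self._read()
            if item % 3 != 0:  # skip every third item
                return item
```

Here `SkipBad` over `[0, 1, 2, 3, 4]` yields 3 samples while `_num_items_read` correctly ends at 5, which is the offset the loader would need for resuming.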

This could become a fairly big undertaking, so maybe there is a smarter solution.

Alternatives

Additional context

Labels: enhancement (New feature or request)