
Add ParallelStreamingDataset #576


Open
philgzl wants to merge 15 commits into main

Conversation

@philgzl (Contributor) commented May 1, 2025

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #554.

This PR adds a new dataset class, ParallelStreamingDataset. Whereas CombinedStreamingDataset yields a sample from one of the datasets it wraps at each iteration, ParallelStreamingDataset yields one sample from each of the datasets it wraps at each iteration.

This makes it possible to truly combine multiple datasets and generate new data on the fly (e.g. multi-source audio signals mixed from different types of sources), which is more flexible than pre-generating, optimizing, and uploading a frozen version of the combined data.

Since ParallelStreamingDataset and CombinedStreamingDataset share common functionality, a base class _BaseStreamingDatasetWrapper was added.

Instead of a weights-like mechanism to compensate for the different lengths of the wrapped datasets, cycling of the wrapped datasets was implemented. Cycling is controlled with the length option: if None, iteration stops as soon as one of the wrapped datasets is exhausted (no cycling); if an integer or float("inf"), the wrapped datasets are cycled until length items are yielded. This might also solve #524, since a StreamingDataset can be wrapped in a ParallelStreamingDataset to cycle it.
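
For illustration, a minimal usage sketch of the length option described above (the directory names are placeholders; the repro example later in the thread shows how such datasets can be generated with optimize):

from litdata import ParallelStreamingDataset, StreamingDataset

d1 = StreamingDataset(input_dir="dataset_1")
d2 = StreamingDataset(input_dir="dataset_2")

# length=None: iteration stops as soon as the shorter dataset is exhausted (no cycling).
no_cycling = ParallelStreamingDataset([d1, d2], length=None)

# Integer length: the wrapped datasets are cycled until 1000 items have been yielded.
fixed_length = ParallelStreamingDataset([d1, d2], length=1000)

# float("inf"): the wrapped datasets are cycled indefinitely.
infinite = ParallelStreamingDataset([d1, d2], length=float("inf"))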

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


codecov bot commented May 1, 2025

Codecov Report

Attention: Patch coverage is 95.47325% with 11 lines in your changes missing coverage. Please review.

Project coverage is 79%. Comparing base (96238b6) to head (5131107).

Additional details and impacted files
@@         Coverage Diff          @@
##           main   #576    +/-   ##
====================================
  Coverage    79%    79%            
====================================
  Files        40     42     +2     
  Lines      6111   6292   +181     
====================================
+ Hits       4809   4979   +170     
- Misses     1302   1313    +11     

@bhimrazy (Collaborator) commented May 5, 2025

Hi @philgzl,
Thanks for creating this PR and adding support for parallel streaming!
I just wanted to share a runtime error I encountered when testing with batch_size >= 3 and num_workers >= 3:

RuntimeError: Failed to get sample from dataset after cycling. Is the dataset empty?

from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset, optimize


def optimize_fn_1(index):
    return index


def optimize_fn_2(index):
    return index + 100


if __name__ == "__main__":
    # optimize(
    #     fn=optimize_fn_1,
    #     inputs=list(range(10)),
    #     output_dir="dataset_1",
    #     num_workers=4,
    #     chunk_bytes="64MB",
    # )

    # optimize(
    #     fn=optimize_fn_2,
    #     inputs=list(range(20)),
    #     output_dir="dataset_2",
    #     num_workers=4,
    #     chunk_bytes="64MB",
    # )

    d1 = StreamingDataset(input_dir="dataset_1")
    d2 = StreamingDataset(input_dir="dataset_2")
    dataset = ParallelStreamingDataset([d1, d2], length=32)
    data_loader = StreamingDataLoader(dataset, batch_size=4, num_workers=4, drop_last=True)

    for batch in data_loader:
        print(batch)
  File "/Users/bhimrajyadav/litdata/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/src/litdata/streaming/parallel.py", line 198, in __next__
    samples, _resets = zip(*[self._get_sample(i) for i in range(len(self._datasets))])
                             ^^^^^^^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/src/litdata/streaming/parallel.py", line 224, in _get_sample
    raise RuntimeError("Failed to get sample from dataset after cycling. Is the dataset empty?")
RuntimeError: Failed to get sample from dataset after cycling. Is the dataset empty?

I’ll keep reviewing and share further feedback if I spot anything else.

@philgzl (Contributor, Author) commented May 5, 2025

I think this is because the fourth worker is assigned 0 samples from dataset 1 (the dataset has 10 samples, the batch size is 4, and there are 4 workers, so the 10 samples are split batch-wise as 4, 4, 2 and 0 across the workers).

The reasoning behind the RuntimeError was: if a dataset raises StopIteration, cycle it and try to fetch again; if for some reason it raises StopIteration again, raise a RuntimeError. A consequence of this is that the case where some workers are assigned 0 samples is not supported. So maybe we should not raise that RuntimeError and instead let the second StopIteration propagate.
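
For reference, a minimal sketch of the fetch-and-cycle behaviour described above (illustrative only, not the actual parallel.py code; the iterator attribute name is an assumption):

def _get_sample(self, i):
    # Try to fetch the next sample from the i-th wrapped dataset.
    try:
        return next(self._dataset_iters[i]), False
    except StopIteration:
        # The dataset is exhausted: cycle it by re-creating its iterator and retry once.
        self._dataset_iters[i] = iter(self._datasets[i])
        try:
            return next(self._dataset_iters[i]), True
        except StopIteration:
            # Current behaviour: a second StopIteration becomes a RuntimeError, which breaks
            # when a worker is assigned 0 samples. The suggestion above is to let this second
            # StopIteration propagate instead.
            raise RuntimeError("Failed to get sample from dataset after cycling. Is the dataset empty?")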

force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.

"""
self._check_datasets(datasets)
@deependujha (Collaborator) commented May 5, 2025


I think ParallelSD can also include CombinedSD. Any potential pitfalls with that?

def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
    if any(not isinstance(d, (StreamingDataset, CombinedStreamingDataset)) for d in datasets):
        raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")

@philgzl (Contributor, Author) replied:

That would be cool, I can give it a try

@philgzl (Contributor, Author) commented May 5, 2025

I thought I would give an example describing what I would like to achieve. Maybe this can help the review or spark ideas to improve the PR.

I have two datasets, and at each iteration, I would like to fetch a sample from each dataset and yield their weighted sum, using random weights.

This PR makes this possible by subclassing ParallelStreamingDataset, with reproducibility and support for StreamingDataLoader:

import itertools
import random

import torch

import litdata as ld
from litdata.utilities.base import __SAMPLES_KEY__
from litdata.utilities.env import _WorkerEnv


class MyParallelDataset(ld.ParallelStreamingDataset):
    def __init__(self, *args, seed=42, **kwargs):
        super().__init__(*args, **kwargs)
        self.seed = seed

    def __iter__(self):
        seed = self._get_seed()
        rng = random.Random(seed)
        for data in super().__iter__():
            if isinstance(data, dict):
                data[__SAMPLES_KEY__] = self._mix(*data[__SAMPLES_KEY__], rng)
            else:
                data = self._mix(*data, rng)
            yield data

    def _mix(self, x, y, rng):
        return rng.random() * x + rng.random() * y

    def _get_seed(self):
        rank = _WorkerEnv.detect().rank
        samples, cycles = self._get_num_samples_yielded()
        samples = [s % len(d) for s, d in zip(samples, self._datasets)]
        return hash((self.seed, rank, *cycles, *samples))


def test(length, batch_size, num_workers, shuffle):
    print(f"length={length}, batch_size={batch_size}, num_workers={num_workers}, shuffle={shuffle}")

    dset = MyParallelDataset(
        [ld.StreamingDataset("data_x", shuffle=shuffle), ld.StreamingDataset("data_y", shuffle=shuffle)],
        length=length,
    )
    dloader = ld.StreamingDataLoader(dset, batch_size=batch_size, num_workers=num_workers)

    for x in dloader:
        pass

    state = dloader.state_dict()

    data = []
    for x in dloader:
        data.append(x)

    dloader.load_state_dict(state)

    for old, new in zip(data, dloader):
        assert torch.equal(old, new)


def fake_data(_):
    return random.random()


def generate_data():
    ld.optimize(
        fn=fake_data,
        inputs=list(range(100)),
        output_dir="data_x",
        chunk_size=1,
        num_workers=4,
    )
    ld.optimize(
        fn=fake_data,
        inputs=list(range(100)),
        output_dir="data_y",
        chunk_size=1,
        num_workers=4,
    )


if __name__ == "__main__":
    # generate_data()

    for length, batch_size, num_workers, shuffle in itertools.product(
        [None, 42, 128],
        [1, 4],
        [0, 4],
        [False, True],
    ):
        test(length, batch_size, num_workers, shuffle)

This works great. However, it is not super user-friendly:

  • need to override __iter__, as __next__ is not a method of ParallelStreamingDataset
  • need to handle whether the sample is a dict with a __SAMPLES_KEY__ key
  • correctly seeding the internal random number generator for reproducibility is not trivial

Possible solutions:

  • Add a mix method intended to be overridden by users, à la collate_fn or transform. This method would be wrapped internally by __next__ or __iter__ and would not have to handle whether the sample is a dict.
  • Provide a ready-to-use random number generator attribute that is correctly seeded at the start of __iter__ and can optionally be used by users inside mix (see the sketch below).
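
A rough sketch of what the first option could look like from the user's side (the mix hook and the pre-seeded rng attribute are hypothetical, not part of this PR):

import litdata as ld


class MyMixingDataset(ld.ParallelStreamingDataset):
    # Hypothetical API: the base class would call mix on the per-dataset samples
    # and handle the __SAMPLES_KEY__ bookkeeping internally.
    def mix(self, x, y):
        # Hypothetical self.rng: seeded by the base class at the start of __iter__
        # from the user seed, worker rank and per-dataset sample/cycle counters.
        return self.rng.random() * x + self.rng.random() * y

This would reduce the user code to the mixing logic itself, similarly to how collate_fn or transform hooks are used.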

Successfully merging this pull request may close these issues.

How to correctly mix multiple StreamingDataset to create new data?