Add ParallelStreamingDataset #576
Conversation
Codecov Report

Attention: Patch coverage is […]

Additional details and impacted files

@@ Coverage Diff @@
##           main   #576    +/- ##
==================================
+ Coverage    79%    80%     +1%
==================================
  Files        41     43      +2
  Lines      6143   6381    +238
==================================
+ Hits       4847   5082    +235
- Misses     1296   1299      +3
Hi @philgzl, I am running into the following error:

RuntimeError: Failed to get sample from dataset after cycling. Is the dataset empty?

Script to reproduce:
from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset, optimize
def optimize_fn_1(index):
    return index


def optimize_fn_2(index):
    return index + 100
if __name__ == "__main__":
# optimize(
# fn=optimize_fn_1,
# inputs=list(range(10)),
# output_dir="dataset_1",
# num_workers=4,
# chunk_bytes="64MB",
# )
# optimize(
# fn=optimize_fn_2,
# inputs=list(range(20)),
# output_dir="dataset_2",
# num_workers=4,
# chunk_bytes="64MB",
# )
d1 = StreamingDataset(input_dir="dataset_1")
d2 = StreamingDataset(input_dir="dataset_2")
dataset = ParallelStreamingDataset([d1, d2], length=32)
data_loader = StreamingDataLoader(dataset, batch_size=4, num_workers=4, drop_last=True)
for batch in data_loader:
print(batch) File "/Users/bhimrajyadav/litdata/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/src/litdata/streaming/parallel.py", line 198, in __next__
    samples, _resets = zip(*[self._get_sample(i) for i in range(len(self._datasets))])
                            ^^^^^^^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/src/litdata/streaming/parallel.py", line 224, in _get_sample
    raise RuntimeError("Failed to get sample from dataset after cycling. Is the dataset empty?")
RuntimeError: Failed to get sample from dataset after cycling. Is the dataset empty?

I'll keep reviewing and share further feedback if I spot anything else.
I think this is because worker 4 is assigned 0 samples from dataset 1 (the dataset length is 10, the batch size is 4, and there are 4 workers). The reasoning behind the RuntimeError was: if a dataset raises a StopIteration, cycle it and try to fetch again; if it raises a StopIteration again for some reason, raise a RuntimeError. A consequence of this is that the case where some workers are assigned 0 samples is not supported. So maybe we should not raise that RuntimeError and instead let the second StopIteration propagate.
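A minimal sketch of that change, as a standalone function (hypothetical names, not the actual implementation in parallel.py):

```python
from typing import Any, Callable, Iterator, List

def get_sample_with_cycling(
    iterators: List[Iterator],
    make_iterator: Callable[[int], Iterator],
    i: int,
) -> Any:
    """Fetch the next sample from wrapped dataset i, cycling it once if exhausted."""
    try:
        return next(iterators[i])
    except StopIteration:
        # Dataset exhausted: cycle it by re-creating its iterator.
        iterators[i] = make_iterator(i)
        # If this raises StopIteration again (e.g. this worker was assigned
        # 0 samples), let it propagate so iteration ends cleanly instead of
        # raising a RuntimeError.
        return next(iterators[i])
```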
I thought I would give an example describing what I would like to achieve. Maybe this can help the review or spark ideas to improve the PR. I have two datasets, and at each iteration, I would like to fetch a sample from each dataset and yield their weighted sum, using random weights. This PR makes that possible by subclassing ParallelStreamingDataset:

import itertools
import random
import torch
import litdata as ld
from litdata.utilities.base import __SAMPLES_KEY__
from litdata.utilities.env import _WorkerEnv

class MyParallelDataset(ld.ParallelStreamingDataset):
    def __init__(self, *args, seed=42, **kwargs):
        super().__init__(*args, **kwargs)
        self.seed = seed

    def __iter__(self):
        seed = self._get_seed()
        rng = random.Random(seed)
        for data in super().__iter__():
            if isinstance(data, dict):
                data[__SAMPLES_KEY__] = self._mix(*data[__SAMPLES_KEY__], rng)
            else:
                data = self._mix(*data, rng)
            yield data

    def _mix(self, x, y, rng):
        # Weighted sum of one sample from each wrapped dataset, with random weights.
        return rng.random() * x + rng.random() * y

    def _get_seed(self):
        # Derive the seed from the worker rank and the per-dataset progress,
        # so mixing is reproducible when the dataloader state is restored.
        rank = _WorkerEnv.detect().rank
        samples, cycles = self._get_num_samples_yielded()
        samples = [s % len(d) for s, d in zip(samples, self._datasets)]
        return hash((self.seed, rank, *cycles, *samples))

def test(length, batch_size, num_workers, shuffle):
    print(f"length={length}, batch_size={batch_size}, num_workers={num_workers}, shuffle={shuffle}")
    dset = MyParallelDataset(
        # forward the shuffle flag to the wrapped datasets
        [ld.StreamingDataset("data_x", shuffle=shuffle), ld.StreamingDataset("data_y", shuffle=shuffle)],
        length=length,
    )
    dloader = ld.StreamingDataLoader(dset, batch_size=batch_size, num_workers=num_workers)
    # First epoch.
    for x in dloader:
        pass
    # Checkpoint, run a second epoch, then restore and check that the
    # second epoch is reproduced exactly.
    state = dloader.state_dict()
    data = []
    for x in dloader:
        data.append(x)
    dloader.load_state_dict(state)
    for old, new in zip(data, dloader):
        assert torch.equal(old, new)

def fake_data(_):
    return random.random()


def generate_data():
    ld.optimize(
        fn=fake_data,
        inputs=list(range(100)),
        output_dir="data_x",
        chunk_size=1,
        num_workers=4,
    )
    ld.optimize(
        fn=fake_data,
        inputs=list(range(100)),
        output_dir="data_y",
        chunk_size=1,
        num_workers=4,
    )
if __name__ == "__main__":
# generate_data()
for length, batch_size, num_workers, shuffle in itertools.product(
[None, 42, 128],
[1, 4],
[0, 4],
[False, True],
):
test(length, batch_size, num_workers, shuffle) This works great. However, it is not super user-friendly:
Possible solutions:
|
Hi! For the reproducibility part, I like the idea of having a […]. Personally, I feel […].
I think with […]. Agree, I might need help to fix the tests on macOS 👀.
For the macOS test, I believe it will require a rerun by the admin to allow it to run beyond 35 minutes, as it's currently exceeding the limit. @philgzl, you can ignore it for now unless an error is reported in the test.
Oh sorry, I increased it to 45 minutes, as the number of tests has grown and to be on the safer side.
Super cool contribution. 🎉
Added some minor comments from my side, and once you add transform method support, it'll be good to merge imo.
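A rough sketch of what such a transform hook could look like from the user side (hypothetical signature and keyword argument; the final API is up to the PR):

```python
import random

import litdata as ld


def mix(samples):
    # Hypothetical transform: receives one sample from each wrapped dataset
    # and returns the combined sample to yield.
    x, y = samples
    w = random.random()
    return w * x + (1 - w) * y


# `transform=` is illustrative only; the actual API may differ.
dataset = ld.ParallelStreamingDataset(
    [ld.StreamingDataset("data_x"), ld.StreamingDataset("data_y")],
    transform=mix,
)
```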
Pull Request Overview
This PR introduces a new dataset class, ParallelStreamingDataset, which enables combining multiple streaming datasets in parallel by yielding a sample from each wrapped dataset on every iteration. Key changes include:
- Implementation of ParallelStreamingDataset in src/litdata/streaming/parallel.py.
- Updates in tests (tests/streaming/* and tests/conftest.py) to integrate and verify the new dataset functionality.
- Adjustments in StreamingDataLoader and documentation (README.md) to support the new API.
Reviewed Changes
Copilot reviewed 12 out of 12 changed files in this pull request and generated no comments.
File | Description
---|---
tests/streaming/test_dataloader.py | Updated assertions and imports for ParallelStreamingDataset. |
tests/conftest.py | Added the parallel_dataset fixture for testing cycling behavior. |
src/litdata/utilities/base.py | Updated base class to include common state handling. |
src/litdata/streaming/parallel.py | New implementation of ParallelStreamingDataset and its iterator. |
src/litdata/streaming/dataloader.py | Adjusted state dict logic and variable names to support parallel datasets. |
src/litdata/streaming/combined.py | Modified CombinedStreamingDataset to subclass the common base. |
README.md | Documentation updated with examples for parallel streaming. |
.github/workflows/ci-testing.yml | Increased timeout minutes. |
Comments suppressed due to low confidence (2)
src/litdata/streaming/parallel.py:303
- [nitpick] The equality assertion checking that all datasets yield the same modulo count might be too strict if the wrapped datasets have different lengths. Consider adding a comment or revisiting this condition to clarify the expected behavior when datasets differ in size.
assert all((dset_lengths[i] * self._num_cycles[i] + self._num_samples_yielded[i]) % self._length == self._count for i in range(1, len(dset_lengths)))
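Restated for clarity (my reading of the invariant, as a standalone check using the snippet's names):

```python
def check_count_invariant(dset_lengths, num_cycles, num_samples_yielded, length, count):
    # For every wrapped dataset i, total progress (samples from completed
    # cycles plus samples yielded in the current cycle), taken modulo the
    # configured length, should equal the shared step counter.
    for i in range(1, len(dset_lengths)):
        progress = dset_lengths[i] * num_cycles[i] + num_samples_yielded[i]
        assert progress % length == count
```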
src/litdata/streaming/dataloader.py:646
- [nitpick] The renaming from '_num_samples_yielded_combined' to '_num_samples_yielded_wrapper' is effective, but please ensure that this consistent naming is documented across the codebase to avoid confusion when debugging state restoration.
self._num_samples_yielded_wrapper = {}
    or self.dataset._length is None
    or self.current_epoch == 0
):
    self._latest_worker_idx = 0
Better to wrap those blocks of code into a method, so they can be tested and extended more easily.
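For example, something along these lines (hypothetical method names; a sketch of the suggestion, not the actual dataloader code):

```python
class StreamingDataLoaderSketch:
    """Illustrative only: the inlined condition extracted into methods."""

    def _should_reset_workers(self) -> bool:
        # Extracted predicate (plus the other conditions from the original
        # block): easy to unit-test and to extend later.
        return self.dataset._length is None or self.current_epoch == 0

    def _reset_workers(self) -> None:
        self._latest_worker_idx = 0
```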
Very cool contribution @philgzl! It would be cool to showcase this in a multimodal example use case.
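For instance, pairing pre-optimized image and caption datasets (a sketch; paths and batch structure are illustrative, not from the PR):

```python
from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset

# Hypothetical pre-optimized datasets.
images = StreamingDataset(input_dir="s3://my-bucket/optimized_images")
captions = StreamingDataset(input_dir="s3://my-bucket/optimized_captions")

# Yields one sample from each modality per iteration.
pairs = ParallelStreamingDataset([images, captions])
loader = StreamingDataLoader(pairs, batch_size=8)

for batch in loader:
    image_batch, caption_batch = batch  # assumed per-dataset batch structure
```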
Before submitting
What does this PR do?
Fixes #554.
This PR adds a new dataset class `ParallelStreamingDataset`. As opposed to `CombinedStreamingDataset`, which yields a sample from one of the datasets it wraps at each iteration, `ParallelStreamingDataset` yields samples from all the datasets it wraps at each iteration. This allows truly combining multiple datasets to generate new data on the fly (e.g. multi-source audio signals from different types of sources), which is more flexible than pre-generating, optimizing, and uploading a frozen version of the combined data.

Since `ParallelStreamingDataset` and `CombinedStreamingDataset` share common functionality, a base class `_BaseStreamingDatasetWrapper` was added.

Instead of adopting a solution similar to `weights` to compensate for the different lengths of the wrapped datasets, cycling of the wrapped datasets was implemented. Cycling is controlled with the `length` option. If `None`, iteration stops as soon as one of the wrapped datasets is exhausted (no cycling). If an integer or `float("inf")`, the wrapped datasets are cycled until `length` items are yielded. This might solve #524, as we can wrap a `StreamingDataset` with `ParallelStreamingDataset` to cycle it.
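A minimal usage sketch of the `length` semantics, based on the examples in this thread (directory names are placeholders):

```python
from litdata import ParallelStreamingDataset, StreamingDataset

d1 = StreamingDataset(input_dir="dataset_1")  # e.g. 10 samples
d2 = StreamingDataset(input_dir="dataset_2")  # e.g. 20 samples

# length=None: stop as soon as the shortest wrapped dataset is exhausted.
no_cycling = ParallelStreamingDataset([d1, d2], length=None)

# length=32: cycle the wrapped datasets until 32 items have been yielded.
cycled = ParallelStreamingDataset([d1, d2], length=32)

for samples in cycled:
    sample_1, sample_2 = samples  # one sample per wrapped dataset (assumed structure)
```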
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃