Add ParallelStreamingDataset
#576
base: main
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
##           main    #576    +/- ##
====================================
  Coverage     79%     79%
====================================
  Files         40      42      +2
  Lines       6111    6292    +181
====================================
+ Hits        4809    4979    +170
- Misses      1302    1313     +11
Hi @philgzl, I am getting the following error when running the script below:
RuntimeError: Failed to get sample from dataset after cycling. Is the dataset empty?

from litdata import ParallelStreamingDataset, StreamingDataLoader, StreamingDataset, optimize


def optimize_fn_1(index):
    return index


def optimize_fn_2(index):
    return index + 100


if __name__ == "__main__":
    # optimize(
    #     fn=optimize_fn_1,
    #     inputs=list(range(10)),
    #     output_dir="dataset_1",
    #     num_workers=4,
    #     chunk_bytes="64MB",
    # )
    # optimize(
    #     fn=optimize_fn_2,
    #     inputs=list(range(20)),
    #     output_dir="dataset_2",
    #     num_workers=4,
    #     chunk_bytes="64MB",
    # )
    d1 = StreamingDataset(input_dir="dataset_1")
    d2 = StreamingDataset(input_dir="dataset_2")
    dataset = ParallelStreamingDataset([d1, d2], length=32)
    data_loader = StreamingDataLoader(dataset, batch_size=4, num_workers=4, drop_last=True)
    for batch in data_loader:
        print(batch)

Traceback (excerpt):

  File "/Users/bhimrajyadav/litdata/.venv/lib/python3.12/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/src/litdata/streaming/parallel.py", line 198, in __next__
    samples, _resets = zip(*[self._get_sample(i) for i in range(len(self._datasets))])
                             ^^^^^^^^^^^^^^^^^^^
  File "/Users/bhimrajyadav/litdata/src/litdata/streaming/parallel.py", line 224, in _get_sample
    raise RuntimeError("Failed to get sample from dataset after cycling. Is the dataset empty?")
RuntimeError: Failed to get sample from dataset after cycling. Is the dataset empty?

I'll keep reviewing and share further feedback if I spot anything else.
I think this is because worker 4 is assigned 0 samples from dataset 1 (the dataset length is 10, the batch size is 4 and the number of workers is 4). The reasoning behind the RuntimeError was: if the dataset raises a StopIteration, cycle it and try to fetch again; if for some reason it raises StopIteration again, raise a RuntimeError. A consequence of this is that the case where some workers are assigned 0 samples is not supported. So maybe we should not raise that RuntimeError and just let the second StopIteration propagate.
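For context, here is a minimal sketch of the cycling behaviour described above and the suggested change. This is a simplified illustration, not the actual code in src/litdata/streaming/parallel.py; the attribute names self._iterators and self._datasets are assumed.

# Simplified sketch of the cycling logic (hypothetical attribute names,
# not the actual litdata implementation).
def _get_sample(self, dataset_index):
    try:
        return next(self._iterators[dataset_index])
    except StopIteration:
        # The wrapped dataset is exhausted: cycle it by recreating its iterator.
        self._iterators[dataset_index] = iter(self._datasets[dataset_index])
    try:
        return next(self._iterators[dataset_index])
    except StopIteration:
        # A second StopIteration means this worker's share of the dataset is empty.
        # Current behaviour: raise a RuntimeError.
        # Suggested change: let the StopIteration propagate, so a worker assigned
        # 0 samples simply stops iterating instead of crashing.
        raise RuntimeError("Failed to get sample from dataset after cycling. Is the dataset empty?")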
force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
"""
self._check_datasets(datasets)
I think ParallelSD can also include CombinedSD. Any potential pitfalls with that?
def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
    if any(not isinstance(d, (StreamingDataset, CombinedStreamingDataset)) for d in datasets):
        raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")
That would be cool, I can give it a try
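If that works out, usage could look something like the sketch below. The directory names and length value are hypothetical, and nesting a CombinedStreamingDataset inside a ParallelStreamingDataset is not supported by this PR yet.

from litdata import CombinedStreamingDataset, ParallelStreamingDataset, StreamingDataset

# Hypothetical example: at every step, mix a sample drawn from one of several
# speech corpora (via CombinedStreamingDataset) with a sample from a noise corpus.
speech = CombinedStreamingDataset(
    [StreamingDataset("speech_corpus_a"), StreamingDataset("speech_corpus_b")]
)
noise = StreamingDataset("noise_corpus")

# Each iteration would yield one sample from `speech` and one from `noise`.
dataset = ParallelStreamingDataset([speech, noise], length=1000)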
I thought I would give an example describing what I would like to achieve. Maybe this can help the review or spark ideas to improve the PR. I have two datasets, and at each iteration, I would like to fetch a sample from each dataset and yield their weighted sum, using random weights. This PR allows doing this by subclassing ParallelStreamingDataset:

import itertools
import random

import torch

import litdata as ld
from litdata.utilities.base import __SAMPLES_KEY__
from litdata.utilities.env import _WorkerEnv


class MyParallelDataset(ld.ParallelStreamingDataset):
    def __init__(self, *args, seed=42, **kwargs):
        super().__init__(*args, **kwargs)
        self.seed = seed

    def __iter__(self):
        seed = self._get_seed()
        rng = random.Random(seed)
        for data in super().__iter__():
            if isinstance(data, dict):
                data[__SAMPLES_KEY__] = self._mix(*data[__SAMPLES_KEY__], rng)
            else:
                data = self._mix(*data, rng)
            yield data

    def _mix(self, x, y, rng):
        return rng.random() * x + rng.random() * y

    def _get_seed(self):
        rank = _WorkerEnv.detect().rank
        samples, cycles = self._get_num_samples_yielded()
        samples = [s % len(d) for s, d in zip(samples, self._datasets)]
        return hash((self.seed, rank, *cycles, *samples))


def test(length, batch_size, num_workers, shuffle):
    print(f"length={length}, batch_size={batch_size}, num_workers={num_workers}, shuffle={shuffle}")
    dset = MyParallelDataset(
        [ld.StreamingDataset("data_x", shuffle=shuffle), ld.StreamingDataset("data_y", shuffle=shuffle)],
        length=length,
    )
    dloader = ld.StreamingDataLoader(dset, batch_size=batch_size, num_workers=num_workers)
    for x in dloader:
        pass
    state = dloader.state_dict()
    data = []
    for x in dloader:
        data.append(x)
    dloader.load_state_dict(state)
    for old, new in zip(data, dloader):
        assert torch.equal(old, new)


def fake_data(_):
    return random.random()


def generate_data():
    ld.optimize(
        fn=fake_data,
        inputs=list(range(100)),
        output_dir="data_x",
        chunk_size=1,
        num_workers=4,
    )
    ld.optimize(
        fn=fake_data,
        inputs=list(range(100)),
        output_dir="data_y",
        chunk_size=1,
        num_workers=4,
    )


if __name__ == "__main__":
    # generate_data()
    for length, batch_size, num_workers, shuffle in itertools.product(
        [None, 42, 128],
        [1, 4],
        [0, 4],
        [False, True],
    ):
        test(length, batch_size, num_workers, shuffle)

This works great. However, it is not super user-friendly:
Possible solutions:
What does this PR do?
Fixes #554.
This PR adds a new dataset class ParallelStreamingDataset. As opposed to CombinedStreamingDataset, which yields a sample from one of the datasets it wraps at each iteration, ParallelStreamingDataset yields samples from all the datasets it wraps at each iteration. This makes it possible to truly combine multiple datasets to generate new data on-the-fly (e.g. multi-source audio signals from different types of sources), which is more flexible than pre-generating, optimizing and uploading a frozen version of the combined data.

Since ParallelStreamingDataset and CombinedStreamingDataset share common functionality, a base class _BaseStreamingDatasetWrapper was added.

Instead of adopting a solution similar to weights to compensate for the different lengths of the wrapped datasets, cycling of the wrapped datasets was implemented. Cycling is controlled with the length option. If None, iteration stops as soon as one of the wrapped datasets is exhausted (no cycling). If an integer or float("inf"), the wrapped datasets are cycled until length items are yielded. This might solve #524, as we can wrap a StreamingDataset with ParallelStreamingDataset to cycle it.
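For illustration, a minimal usage sketch of the length option based on the description above; the dataset directories and sizes are hypothetical.

from litdata import ParallelStreamingDataset, StreamingDataset

# Hypothetical pre-optimized datasets of different sizes.
short_ds = StreamingDataset(input_dir="short_dataset")  # e.g. 10 samples
long_ds = StreamingDataset(input_dir="long_dataset")    # e.g. 100 samples

# length=None: iteration stops as soon as short_ds is exhausted (no cycling).
finite = ParallelStreamingDataset([short_ds, long_ds], length=None)

# length=50: the datasets are cycled until 50 items (one sample from each) are yielded.
fixed = ParallelStreamingDataset([short_ds, long_ds], length=50)

# length=float("inf"): the datasets are cycled indefinitely.
infinite = ParallelStreamingDataset([short_ds, long_ds], length=float("inf"))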
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃