
Add ParallelStreamingDataset #576


Merged: 35 commits, May 25, 2025
Changes from 31 commits
Commits (35)
0c03590
Add ParallelStreamingDataset
philgzl Apr 19, 2025
aec7632
Add tests
philgzl Apr 24, 2025
a8cb189
Make __len__ an abstract method of base dataset wrapper
philgzl Apr 24, 2025
f81c25b
Update tests
philgzl Apr 28, 2025
77e1847
Add num_cycles attribute
philgzl Apr 29, 2025
aeb6e3f
Finish cooking
philgzl May 1, 2025
adc97e6
Try to clean diff and add docs
philgzl May 1, 2025
6e2c4fe
Fix mypy and 3.8 type hint errors
philgzl May 1, 2025
a5bb83a
Fix mypy errors for real this time
philgzl May 1, 2025
376a28a
Merge branch 'main' into parallel-dataset
bhimrazy May 3, 2025
a7d941a
Merge branch 'main' into parallel-dataset
bhimrazy May 5, 2025
b9a21e4
Remove _BaseDatasetWrapperIterator
philgzl May 5, 2025
2a19ff5
Update tests
philgzl May 5, 2025
bb34245
Fix RuntimeError when some workers are assigned 0 samples
philgzl May 5, 2025
5131107
Fix mypy errors
philgzl May 5, 2025
b299b1b
Increase CI timeout from 35 to 45 minutes
deependujha May 7, 2025
216ee0e
Update test_parallel.py to skip tests on macOS in addition to Windows
deependujha May 7, 2025
65d0776
Add transform
philgzl May 8, 2025
1f6563f
Update README.md
philgzl May 8, 2025
6462652
Fix can't pickle local object error
philgzl May 9, 2025
63d0bbe
Update README.md
philgzl May 9, 2025
47aa5c5
Skip more tests on win32 and darwin
philgzl May 9, 2025
2b92d46
Update docstrings
philgzl May 9, 2025
673b716
Add comment in get_len
philgzl May 9, 2025
f49bd19
Replace tmpdir and tmdir_factory with tmp_path and tmp_path_factory
philgzl May 9, 2025
9e06851
Merge branch 'main' into parallel-dataset
deependujha May 11, 2025
696e471
Apply suggestions
philgzl May 18, 2025
f880bbd
Update tests
philgzl May 18, 2025
1f6bebe
Update README.md
philgzl May 18, 2025
3e81d4d
Merge branch 'main' into parallel-dataset
philgzl May 18, 2025
59a52ad
Fix list type hint
philgzl May 18, 2025
ec4d706
Change samples_yieled to samples_yielded
philgzl May 22, 2025
f7c63a9
Skip even more tests to fit in macos CI time limit
philgzl May 22, 2025
b1a821a
Merge branch 'main' into parallel-dataset
philgzl May 22, 2025
67c0269
Merge branch 'main' into parallel-dataset
bhimrazy May 25, 2025
77 changes: 77 additions & 0 deletions README.md
@@ -649,6 +649,83 @@ for batch in tqdm(train_dataloader):
```
</details>

<details>
<summary> ✅ Parallel streaming</summary>
&nbsp;

While `CombinedStreamingDataset` fetches a sample from one of the wrapped datasets at each iteration, `ParallelStreamingDataset` fetches a sample from all the wrapped datasets at each iteration:

```python
from litdata import StreamingDataset, ParallelStreamingDataset, StreamingDataLoader
from tqdm import tqdm

parallel_dataset = ParallelStreamingDataset(
[
StreamingDataset(input_dir="input_dir_1"),
StreamingDataset(input_dir="input_dir_2"),
],
)

dataloader = StreamingDataLoader(parallel_dataset)

for batch_1, batch_2 in tqdm(dataloader):  # one batch from each wrapped dataset
pass
```

This is useful for generating new data on the fly from one sample of each dataset. To do so, provide a `transform` function to `ParallelStreamingDataset`:

```python
def transform(samples: Tuple[Any, ...]):
    sample_1, sample_2 = samples  # as many samples as wrapped datasets
    return sample_1 + sample_2  # example transformation

parallel_dataset = ParallelStreamingDataset([dset_1, dset_2], transform=transform)

dataloader = StreamingDataLoader(parallel_dataset)

for transformed_batch in tqdm(dataloader):
pass
```

If the transformation requires random number generation, the internal random number generators provided by `ParallelStreamingDataset` can be used. These are seeded with the current dataset state at the beginning of each epoch, which makes the transformation reproducible and resumable. To use them, define a `transform` that takes a dictionary of random number generators as its second argument:

```python
def transform(samples: Tuple[Any, ...], rngs: Dict[str, Any]):
    sample_1, sample_2 = samples  # as many samples as wrapped datasets
    rng = rngs["random"]  # "random", "numpy" and "torch" keys available
    return rng.random() * sample_1 + rng.random() * sample_2  # example transformation

parallel_dataset = ParallelStreamingDataset([dset_1, dset_2], transform=transform)
```
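
The same pattern works with the other generators in the dictionary. Below is a minimal sketch using the `"torch"` entry; it assumes the wrapped datasets yield tensors and that the entry behaves like a `torch.Generator`, neither of which is guaranteed by the snippet above:

```python
import torch

def transform(samples, rngs):
    sample_1, sample_2 = samples
    gen = rngs["torch"]  # assumed to be a torch.Generator
    weight = torch.rand(1, generator=gen)  # reproducible mixing weight
    return weight * sample_1 + (1 - weight) * sample_2

parallel_dataset = ParallelStreamingDataset([dset_1, dset_2], transform=transform)
```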
</details>

<details>
<summary> ✅ Cycle datasets</summary>
&nbsp;

`ParallelStreamingDataset` can also be used to cycle a `StreamingDataset`. This decouples the epoch length from the number of samples in the dataset.

To do so, set the `length` option to the desired number of samples to yield per epoch. If `length` is greater than the number of samples in the dataset, the dataset is cycled. At the beginning of a new epoch, the dataset resumes from where it left off at the end of the previous epoch.

```python
from litdata import StreamingDataset, ParallelStreamingDataset, StreamingDataLoader
from tqdm import tqdm

dataset = StreamingDataset(input_dir="input_dir")

cycled_dataset = ParallelStreamingDataset([dataset], length=100)

print(len(cycled_dataset))  # 100

dataloader = StreamingDataLoader(cycled_dataset)

for batch, in tqdm(dataloader):  # note the trailing comma: each batch is a 1-tuple
pass
```
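
To make the resume-on-new-epoch behaviour concrete, here is a hedged sketch reusing `cycled_dataset` and `dataloader` from the snippet above; the 60-sample dataset size and the index arithmetic in the comments are hypothetical and assume shuffling is disabled:

```python
# Suppose the wrapped dataset holds 60 samples and length=100.
# Epoch 1 yields samples 0..59, cycles, then yields samples 0..39 again.
# Epoch 2 does not restart from 0: it resumes the cycle at sample 40.
for epoch in range(2):
    for batch, in dataloader:
        pass  # 100 samples per epoch, continuing the cycle across epochs
```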

You can even set `length` to `float("inf")` for an infinite dataset!
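
Since an infinite dataset never ends on its own, the consuming loop needs its own stopping condition. A minimal sketch, using `itertools.islice` as one way to bound the loop (the cutoff of 1000 batches is arbitrary):

```python
from itertools import islice

infinite_dataset = ParallelStreamingDataset([dataset], length=float("inf"))

dataloader = StreamingDataLoader(infinite_dataset)

for batch, in islice(dataloader, 1000):  # stop after 1000 batches of the endless stream
    pass
```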
</details>

<details>
<summary> ✅ Merge datasets</summary>
&nbsp;
2 changes: 2 additions & 0 deletions src/litdata/__init__.py
@@ -18,6 +18,7 @@
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import TokensLoader
from litdata.streaming.parallel import ParallelStreamingDataset
from litdata.streaming.writer import index_parquet_dataset
from litdata.utilities.breakpoint import breakpoint
from litdata.utilities.hf_dataset import index_hf_dataset
@@ -28,6 +29,7 @@
"CombinedStreamingDataset",
"StreamingDataLoader",
"TokensLoader",
"ParallelStreamingDataset",
"map",
"optimize",
"walk",
2 changes: 2 additions & 0 deletions src/litdata/streaming/__init__.py
@@ -16,11 +16,13 @@
from litdata.streaming.dataloader import StreamingDataLoader
from litdata.streaming.dataset import StreamingDataset
from litdata.streaming.item_loader import TokensLoader
from litdata.streaming.parallel import ParallelStreamingDataset

__all__ = [
"Cache",
"StreamingDataset",
"CombinedStreamingDataset",
"StreamingDataLoader",
"TokensLoader",
"ParallelStreamingDataset",
]
78 changes: 7 additions & 71 deletions src/litdata/streaming/combined.py
@@ -16,15 +16,15 @@
from copy import deepcopy
from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence

from torch.utils.data import IterableDataset

from litdata.debugger import ChromeTraceColors, _get_log_msg
from litdata.streaming.dataset import StreamingDataset
from litdata.utilities.base import (
__NUM_SAMPLES_YIELDED_KEY__,
__SAMPLES_KEY__,
_BaseStreamingDatasetWrapper,
)
from litdata.utilities.env import _WorkerEnv

__NUM_SAMPLES_YIELDED_KEY__ = "__NUM_SAMPLES_YIELDED__"
__SAMPLES_KEY__ = "__SAMPLES__"

logger = logging.getLogger("litdata.streaming.combined")


@@ -36,7 +36,7 @@ class BatchingMethod:
BatchingMethodType = Literal["stratified", "per_stream"]


class CombinedStreamingDataset(IterableDataset):
class CombinedStreamingDataset(_BaseStreamingDatasetWrapper):
"""Enables to stream data from multiple StreamingDataset with the sampling ratio of
your choice.

@@ -99,7 +99,7 @@ def __init__(

self._iterator: Optional[_CombinedDatasetIterator] = None
self._use_streaming_dataloader = False
self._num_samples_yielded: Optional[List[int]] = None
self._num_samples_yielded: Optional[Dict[int, List[int]]] = None
self._current_epoch = 0
self.num_workers = 1
self.batch_size = 1
Expand All @@ -119,11 +119,6 @@ def __len__(self) -> Optional[int]:
def _get_total_length(self) -> int:
return sum(self._get_len(d) for d in self._datasets)

def _get_len(self, d: Any) -> int:
if isinstance(d, StreamingDataset):
return d.get_len(self.num_workers, self.batch_size)
return len(d)

def set_epoch(self, current_epoch: int) -> None:
"""Set the current epoch to the datasets on epoch starts.

@@ -134,40 +129,6 @@ def set_epoch(self, current_epoch: int) -> None:
for dataset in self._datasets:
dataset.set_epoch(current_epoch)

def set_shuffle(self, shuffle: bool) -> None:
"""Set the current shuffle to the datasets."""
for dataset in self._datasets:
dataset.set_shuffle(shuffle)

def set_batch_size(self, batch_size: int) -> None:
"""Set the current batch size to the datasets."""
self.batch_size = batch_size
for dataset in self._datasets:
dataset.set_batch_size(batch_size)

def set_num_workers(self, num_workers: int) -> None:
"""Set the current number of workers to the datasets."""
for dataset in self._datasets:
dataset.set_num_workers(num_workers)

def set_drop_last(self, drop_last: bool) -> None:
"""Set the current drop_last to the datasets."""
for dataset in self._datasets:
dataset.set_drop_last(drop_last)

def reset_state_dict(self) -> None:
"""Reset the state of the dataset."""
for dataset in self._datasets:
dataset.reset_state_dict()

def _check_datasets(self, datasets: List[StreamingDataset]) -> None:
if any(not isinstance(d, StreamingDataset) for d in datasets):
raise RuntimeError("The provided datasets should be instances of the StreamingDataset.")

def _set_use_streaming_dataloader(self, use_streaming_dataloader: bool) -> None:
# Used to prevent returning num_samples_yielded when using PyTorch DataLoader
self._use_streaming_dataloader = use_streaming_dataloader

def __iter__(self) -> Iterator[Any]:
assert self._weights

@@ -199,31 +160,6 @@ def state_dict(
return _state_dict(self._datasets, num_samples_yielded, num_workers, batch_size)
return self._iterator.state_dict(num_workers, batch_size)

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
if not state_dict:
return

if len(state_dict["dataset"]) != len(self._datasets):
if not self._force_override_state_dict:
raise RuntimeError(
f"The provided state doesn't match the current number of datasets: {self._datasets}."
)
if len(state_dict["dataset"]) > len(self._datasets):
raise RuntimeError(
"Currently it's only possible to add datasets to the end of the dataset list when overriding state"
)

for dataset_idx, dataset in enumerate(self._datasets):
if str(dataset_idx) in state_dict["dataset"]:
dataset.load_state_dict(state_dict["dataset"][str(dataset_idx)])

elif not self._force_override_state_dict:
raise RuntimeError(f"The provided state doesn't contain the index {dataset_idx}.")

# Used to iterate over the sampler to avoid sampling the same samples
if self._use_streaming_dataloader:
self._num_samples_yielded = state_dict["num_samples_yielded"]


class _CombinedDatasetIterator(Iterator):
def __init__(