
Commit cb3fc81

Authored by schopra8, pre-commit-ci[bot], bhimrazy, and Borda
Feat: Add per_stream batching method to CombinedStreamingDataset (#438)
* Init implementation of per_stream batching
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* Address pre-commit issues
* fixed bug in combined.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* expose _set_new_dataset_index() to dataloader
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* WIP - communicate to worker loop
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* remove print
* cleanup print
* updated doc string
* added types and also update docs
* revert: changes from dataloader
* update the combined streaming dataset with batching method
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* revert changes from dataloader
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* add test for combined dataset with per-stream batching
* remove set_batching_method from CombinedStreamingDataset to simplify batch handling
* update comment
* refactor: replace string literals with BatchingMethod constants for clarity
* cleanup
* refactor: variables initialization

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent 68d23cd commit cb3fc81
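
As a quick orientation before the diff, here is a minimal usage sketch of the new parameter. The `data/dataset_a` and `data/dataset_b` paths are hypothetical placeholders for directories already prepared in litdata's streaming format, and the top-level imports follow the public API exercised by the test added below; treat this as an illustration, not part of the patch.

# Minimal sketch of per_stream batching (hypothetical dataset paths).
from litdata import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset

dataset_a = StreamingDataset(input_dir="data/dataset_a")
dataset_b = StreamingDataset(input_dir="data/dataset_b")

# With batching_method="per_stream", every batch is drawn from a single
# dataset, chosen at random per batch; the default "stratified" method
# instead mixes samples from all datasets within each batch.
combined = CombinedStreamingDataset(
    datasets=[dataset_a, dataset_b],
    seed=42,
    batching_method="per_stream",
)
dataloader = StreamingDataLoader(combined, batch_size=8, drop_last=True)

for batch in dataloader:
    ...  # each batch contains samples from exactly one of the two datasets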

File tree

2 files changed: +68 -1 lines changed


src/litdata/streaming/combined.py (+43 -1)

@@ -14,7 +14,7 @@
 import logging
 import random
 from copy import deepcopy
-from typing import Any, Dict, Iterator, List, Optional, Sequence
+from typing import Any, Dict, Iterator, List, Literal, Optional, Sequence
 
 from torch.utils.data import IterableDataset
 
@@ -28,6 +28,14 @@
 logger = logging.getLogger("litdata.streaming.combined")
 
 
+class BatchingMethod:
+    STRATIFIED = "stratified"
+    PER_STREAM = "per_stream"
+
+
+BatchingMethodType = Literal["stratified", "per_stream"]
+
+
 class CombinedStreamingDataset(IterableDataset):
     """Enables to stream data from multiple StreamingDataset with the sampling ratio of
     your choice.
@@ -46,6 +54,7 @@ def __init__(
         seed: int = 42,
         weights: Optional[Sequence[float]] = None,
         iterate_over_all: bool = True,
+        batching_method: BatchingMethodType = "stratified",
        force_override_state_dict: bool = False,
     ) -> None:
         """Enable to stream data from multiple StreamingDataset with the sampling ratio of your choice.
@@ -56,7 +65,11 @@
             weights: The sampling ratio for the datasets
             iterate_over_all: When iterate_over_all is True, the combined dataset iterates over all the datasets.
                 Otherwise, it stops as soon as one raises a StopIteration.
+            batching_method (str, optional): When batching_method is set to "stratified" (default),
+                batches will include samples from all datasets. On the other hand, when batching_method is "per_stream",
+                batches will consist of samples from a single dataset, which is selected randomly.
             force_override_state_dict: Boolean flag for allowing local arguments to override a loaded state dict.
+
         """
         self._check_datasets(datasets)
 
@@ -90,6 +103,7 @@ def __init__(
         self._current_epoch = 0
         self.num_workers = 1
         self.batch_size = 1
+        self._batching_method: BatchingMethodType = batching_method
 
     def get_len(self, num_workers: int, batch_size: int) -> Optional[int]:
         self.num_workers = num_workers
@@ -170,6 +184,8 @@ def __iter__(self) -> Iterator[Any]:
             self._weights,
             self._use_streaming_dataloader,
             num_samples_yielded,
+            self.batch_size,
+            self._batching_method,
             self._iterate_over_all,
         )
         return self._iterator
@@ -217,6 +233,8 @@ def __init__(
         weights: Sequence[Optional[float]],
         use_streaming_dataloader: bool,
         num_samples_yielded: Any,
+        batch_size: int,
+        batching_method: BatchingMethodType,
         iterate_over_all: bool = False,
     ) -> None:
         self._datasets = datasets
@@ -227,6 +245,8 @@ def __init__(
         self._weights = deepcopy(weights)
         self._rng = random.Random(seed)  # noqa: S311
         self._iterate_over_all = iterate_over_all
+        self._batching_method = batching_method
+        self._batch_size = batch_size
         self._is_done = False
 
         if num_samples_yielded is not None:
@@ -238,6 +258,13 @@ def __init__(
 
         self._use_streaming_dataloader = use_streaming_dataloader
         self._is_done = False
+
+        # Used to track the number of samples yielded in the current batch
+        # and the current dataset index
+        # This is used only when batching_method is set to "per_stream"
+        self._samples_yielded_in_batch = 0
+        self._cur_dataset_index = -1
+
         logger.debug(
             _get_log_msg({"name": "iterating_combined_dataset", "ph": "B", "cname": ChromeTraceColors.LIGHT_BLUE})
         )
@@ -272,6 +299,21 @@ def __next__(self) -> Any:
         return self._get_sample(self._get_dataset_index())
 
     def _get_dataset_index(self) -> int:
+        if self._batching_method == BatchingMethod.STRATIFIED:
+            # For every sample, randomly select a dataset (weighted)
+            dataset_idx = self._set_new_dataset_index()
+        elif self._batching_method == BatchingMethod.PER_STREAM:
+            # For each batch, pick a dataset and stick with it for the whole batch
+            if self._cur_dataset_index == -1 or self._samples_yielded_in_batch >= self._batch_size:
+                self._cur_dataset_index = self._set_new_dataset_index()
+                self._samples_yielded_in_batch = 0
+            dataset_idx = self._cur_dataset_index
+            self._samples_yielded_in_batch += 1
+        else:
+            raise ValueError(f"Invalid batching method: {self._batching_method}")
+        return dataset_idx
+
+    def _set_new_dataset_index(self) -> int:
        # randomly select a dataset index
        indexes = [index for index in self._dataset_indexes if index is not None]
        weights = [w for w in self._weights if w is not None]
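
The heart of the patch is `_get_dataset_index`. To show the per-stream mechanics outside the diff, below is a self-contained, simplified stand-in (an illustration, not the library code): it keeps the same two pieces of state, `_cur_dataset_index` and `_samples_yielded_in_batch`, and re-rolls the weighted choice only at batch boundaries.

import random
from typing import List

# Simplified stand-in for the iterator's per-stream selection logic.
class PerStreamSelector:
    def __init__(self, weights: List[float], batch_size: int, seed: int = 42) -> None:
        self._weights = weights
        self._batch_size = batch_size
        self._rng = random.Random(seed)
        self._samples_yielded_in_batch = 0  # samples drawn from the current dataset
        self._cur_dataset_index = -1        # -1 means "no dataset picked yet"

    def _set_new_dataset_index(self) -> int:
        # Weighted random choice over dataset indexes, as in the patch.
        indexes = list(range(len(self._weights)))
        return self._rng.choices(indexes, weights=self._weights)[0]

    def get_dataset_index(self) -> int:
        # Pick a new dataset at a batch boundary, then stick with it
        # until batch_size samples have been yielded.
        if self._cur_dataset_index == -1 or self._samples_yielded_in_batch >= self._batch_size:
            self._cur_dataset_index = self._set_new_dataset_index()
            self._samples_yielded_in_batch = 0
        self._samples_yielded_in_batch += 1
        return self._cur_dataset_index

selector = PerStreamSelector(weights=[0.5, 0.5], batch_size=4)
picks = [selector.get_dataset_index() for _ in range(12)]
# Every run of 4 consecutive picks shares one dataset index, e.g.
# [0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0]
print(picks)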

tests/streaming/test_combined.py (+25)

@@ -287,6 +287,31 @@ def test_combined_dataset():
     assert torch.equal(next(dataloader_iter), torch.Tensor([0, 1]))
 
 
+@pytest.mark.parametrize("batch_size", [2, 4])
+@pytest.mark.parametrize("num_workers", [1, 2])
+def test_combined_dataset_with_per_stream_batching(tmpdir, batch_size, num_workers):
+    num_of_datasets = 2
+    dataset_ranges = [(0, 10), (10, 20)]
+    dataset_paths = [str(tmpdir.join(f"dataset_{i}")) for i in range(num_of_datasets)]
+    for dataset_path, (start, end) in zip(dataset_paths, dataset_ranges):
+        os.makedirs(dataset_path)
+        cache = Cache(input_dir=dataset_path, chunk_size=2)
+        for i in range(start, end):
+            cache[i] = i
+        cache.done()
+        cache.merge()
+
+    datasets = [StreamingDataset(input_dir=str(dataset_path)) for dataset_path in dataset_paths]
+    dataset = CombinedStreamingDataset(datasets=datasets, seed=12345, batching_method="per_stream")
+    dataloader = StreamingDataLoader(dataset, batch_size=batch_size, num_workers=num_workers, drop_last=True)
+
+    for batch in dataloader:
+        # Ensure that the batch contains items exclusively from a single dataset
+        assert all(x in range(0, 10) for x in batch) or all(x in range(10, 20) for x in batch), (
+            f"Batch should contain elements from only one dataset but got {batch}"
+        )
+
+
 @pytest.mark.parametrize("batch_size", [1, 2])
 def test_combined_dataset_with_dataloader_and_one_worker(batch_size):
     dataset1 = SimpleDataset(0, 10)
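
To run only the new test locally, invoking pytest by file and keyword should work; the path assumes the repository root, and `-q` just quiets the output.

import pytest

# Run only the new per-stream batching test from the litdata repo root.
pytest.main(["tests/streaming/test_combined.py", "-k", "per_stream_batching", "-q"])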
