diff --git a/docs/source/package_reference/main_classes.mdx b/docs/source/package_reference/main_classes.mdx
index 185bde10d72..62dc9127d4b 100644
--- a/docs/source/package_reference/main_classes.mdx
+++ b/docs/source/package_reference/main_classes.mdx
@@ -52,6 +52,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
     - take
     - train_test_split
     - shard
+    - repeat
     - to_tf_dataset
     - push_to_hub
     - save_to_disk
@@ -172,6 +173,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
     - skip
     - take
     - shard
+    - repeat
     - load_state_dict
     - state_dict
     - info
diff --git a/src/datasets/arrow_dataset.py b/src/datasets/arrow_dataset.py
index f272b3ec24b..3f1cca6e69c 100644
--- a/src/datasets/arrow_dataset.py
+++ b/src/datasets/arrow_dataset.py
@@ -4078,6 +4078,38 @@ def skip(self, n: int) -> "Dataset":
         """
         return self.select(range(n, len(self)))
 
+    def repeat(self, num_times: int) -> "Dataset":
+        """
+        Create a new [`Dataset`] that repeats the underlying dataset `num_times` times.
+
+        Like itertools.repeat, repeating once just returns the full dataset.
+
+        Args:
+            num_times (`int`):
+                Number of times to repeat the dataset.
+
+        Example:
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("rotten_tomatoes", split="train")
+        >>> ds = ds.take(2).repeat(2)
+        >>> list(ds)
+        [{'label': 1,
+         'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
+         {'label': 1,
+         'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
+         {'label': 1, 'text': 'effective but too-tepid biopic'},
+         {'label': 1,
+         'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
+         {'label': 1,
+         'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
+         {'label': 1, 'text': 'effective but too-tepid biopic'}]
+        ```
+        """
+        if num_times is None:
+            raise ValueError("Map style datasets do not support indefinite repetition.")
+        return _concatenate_map_style_datasets([self] * num_times) if num_times > 0 else self.select([])
+
     def take(self, n: int) -> "Dataset":
         """
         Create a new [`Dataset`] with only the first `n` elements.
diff --git a/src/datasets/iterable_dataset.py b/src/datasets/iterable_dataset.py
index f2d47ff64b9..627a6c03824 100644
--- a/src/datasets/iterable_dataset.py
+++ b/src/datasets/iterable_dataset.py
@@ -1622,6 +1622,54 @@ def num_shards(self) -> int:
         return self.ex_iterable.num_shards
 
 
+class RepeatExamplesIterable(_BaseExamplesIterable):
+    """
+    Iterable that repeats the underlying iterable a given number of times.
+    """
+
+    def __init__(
+        self,
+        ex_iterable: _BaseExamplesIterable,
+        num_times: Optional[int],
+    ):
+        super().__init__()
+        self.ex_iterable = ex_iterable
+        self.num_times = num_times
+
+    def _init_state_dict(self) -> dict:
+        self._state_dict = {
+            "repeat_index": 0,
+            "ex_iterable": self.ex_iterable._init_state_dict(),
+        }
+        return self._state_dict
+
+    def __iter__(self):
+        repeat_index = self._state_dict["repeat_index"] if self._state_dict else 0
+        while True:
+            if self.num_times is not None and repeat_index >= max(self.num_times, 0):
+                break
+            yield from self.ex_iterable
+            repeat_index += 1
+            if self._state_dict:
+                self._state_dict["repeat_index"] = repeat_index
+                self._state_dict["ex_iterable"] = self.ex_iterable._init_state_dict()
+
+    def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable":
+        """Shuffle the underlying iterable, then repeat."""
+        return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)
+
+    def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
+        """Shard, then repeat shards."""
+        return RepeatExamplesIterable(
+            self.ex_iterable.shard_data_sources(worker_id, num_workers),
+            num_times=self.num_times,
+        )
+
+    @property
+    def num_shards(self) -> int:
+        return self.ex_iterable.num_shards
+
+
 class TakeExamplesIterable(_BaseExamplesIterable):
     def __init__(
         self,
@@ -2762,6 +2810,49 @@ def skip(self, n: int) -> "IterableDataset":
             token_per_repo_id=self._token_per_repo_id,
         )
 
+    def repeat(self, num_times: Optional[int]) -> "IterableDataset":
+        """
+        Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times.
+
+        N.B. The effect of calling shuffle after repeat depends significantly on buffer size.
+        With buffer_size 1, duplicate data is never seen in the same iteration, even after shuffling:
+        ds.repeat(n).shuffle(seed=42, buffer_size=1) is equivalent to ds.shuffle(seed=42, buffer_size=1).repeat(n),
+        and only shuffles shard orders within each iteration.
+        With buffer size >= (num samples in the dataset * num_times), we get full shuffling of the repeated data, i.e. we can observe duplicates in
+        the same iteration.
+
+        Args:
+            num_times (`int`) or (`None`):
+                Number of times to repeat the dataset. If `None`, the dataset will be repeated indefinitely.
+
+        Example:
+        ```py
+        >>> from datasets import load_dataset
+        >>> ds = load_dataset("rotten_tomatoes", split="train")
+        >>> ds = ds.take(2).repeat(2)
+        >>> list(ds)
+        [{'label': 1,
+         'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
+         {'label': 1,
+         'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
+         {'label': 1, 'text': 'effective but too-tepid biopic'},
+         {'label': 1,
+         'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
+         {'label': 1,
+         'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
+         {'label': 1, 'text': 'effective but too-tepid biopic'}]
+        ```
+        """
+        return IterableDataset(
+            ex_iterable=RepeatExamplesIterable(self._ex_iterable, num_times=num_times),
+            info=self._info,
+            split=self._split,
+            formatting=self._formatting,
+            shuffling=copy.deepcopy(self._shuffling),
+            distributed=copy.deepcopy(self._distributed),
+            token_per_repo_id=self._token_per_repo_id,
+        )
+
     def take(self, n: int) -> "IterableDataset":
         """
         Create a new [`IterableDataset`] with only the first `n` elements.
diff --git a/tests/test_arrow_dataset.py b/tests/test_arrow_dataset.py
index 6cf8898ce67..766174b15a1 100644
--- a/tests/test_arrow_dataset.py
+++ b/tests/test_arrow_dataset.py
@@ -869,6 +869,32 @@ def test_concatenate_pickle(self, in_memory):
         self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")
         del dset1, dset2, dset3
 
+    def test_repeat(self, in_memory):
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                repeated_dset = dset.repeat(3)
+                column_values_dict = {col: dset[col] for col in dset.column_names}
+                for col, single_values in column_values_dict.items():
+                    self.assertListEqual(repeated_dset[col], single_values * 3)
+                del repeated_dset
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                with pytest.raises(ValueError):
+                    dset.repeat(None)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                repeated_dset = dset.repeat(0)
+                self.assertEqual(len(repeated_dset), 0)
+                del repeated_dset
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
+                repeated_dset = dset.repeat(-1)
+                self.assertEqual(len(repeated_dset), 0)
+                del repeated_dset
+
     def test_flatten(self, in_memory):
         with tempfile.TemporaryDirectory() as tmp_dir:
             with Dataset.from_dict(
diff --git a/tests/test_iterable_dataset.py b/tests/test_iterable_dataset.py
index f3f7ee3106f..8a972ec9cd3 100644
--- a/tests/test_iterable_dataset.py
+++ b/tests/test_iterable_dataset.py
@@ -33,6 +33,7 @@
     MappedExamplesIterable,
     RandomlyCyclingMultiSourcesExamplesIterable,
     RebatchedArrowExamplesIterable,
+    RepeatExamplesIterable,
     SelectColumnsIterable,
     ShuffledDataSourcesArrowExamplesIterable,
     ShuffledDataSourcesExamplesIterable,
@@ -1165,6 +1166,28 @@ def test_take_examples_iterable():
     assert_load_state_dict_resumes_iteration(take_ex_iterable)
 
 
+@pytest.mark.parametrize(
+    "n, num_times",
+    [
+        (3, None),
+        (3, 3),
+        (3, 0),
+    ],
+)
+def test_repeat_examples_iterable(n, num_times):
+    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
+    ex_iterable = RepeatExamplesIterable(base_ex_iterable, num_times=num_times)
+    all_examples = [x for _, x in generate_examples_fn(n=n)]
+    if num_times is not None:
+        expected = all_examples * max(num_times, 0)
+        assert [x for _, x in ex_iterable] == expected
+    else:
+        max_iters = 135
+        iterator = iter(ex_iterable)
+        for i in range(max_iters):
+            assert next(iterator)[1] == all_examples[i % len(all_examples)], f"iteration {i} failed"
+
+
 def test_vertically_concatenated_examples_iterable():
     ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10})
     ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5})
@@ -1735,6 +1758,14 @@ def test_iterable_dataset_take(dataset: IterableDataset, n):
     assert list(take_dataset) == list(dataset)[:n]
 
 
+@pytest.mark.parametrize("n", [0, 2])
+def test_iterable_dataset_repeat(dataset: IterableDataset, n):
+    repeat_dataset = dataset.repeat(n)
+    assert isinstance(repeat_dataset._ex_iterable, RepeatExamplesIterable)
+    assert repeat_dataset._ex_iterable.num_times == n
+    assert list(repeat_dataset) == list(dataset) * n
+
+
 def test_iterable_dataset_shard():
     num_examples = 20
     num_shards = 5