Skip to content

Commit 3a4e74a

Browse files
alex-hh and lhoestq
authored
Add repeat method to datasets (#7198)
* implement repeat method for iterable dataset * implement repeat method for map-style dataset * fix iterable dataset repeat * address pr comments * add test case for map-style dataset * add test cases for iterable datasets * fix code formatting * Update test_arrow_dataset.py * Update test_arrow_dataset.py * Update main_classes.mdx --------- Co-authored-by: Quentin Lhoest <[email protected]>
1 parent 6fa8fb3 commit 3a4e74a

File tree

5 files changed

+182
-0
lines changed

5 files changed

+182
-0
lines changed

docs/source/package_reference/main_classes.mdx

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ The base class [`Dataset`] implements a Dataset backed by an Apache Arrow table.
5252
- take
5353
- train_test_split
5454
- shard
55+
- repeat
5556
- to_tf_dataset
5657
- push_to_hub
5758
- save_to_disk
@@ -172,6 +173,7 @@ The base class [`IterableDataset`] implements an iterable Dataset backed by pyth
172173
- skip
173174
- take
174175
- shard
176+
- repeat
175177
- load_state_dict
176178
- state_dict
177179
- info

src/datasets/arrow_dataset.py

+32
Original file line numberDiff line numberDiff line change
@@ -4078,6 +4078,38 @@ def skip(self, n: int) -> "Dataset":
40784078
"""
40794079
return self.select(range(n, len(self)))
40804080

4081+
def repeat(self, num_times: int) -> "Dataset":
    """
    Create a new [`Dataset`] that repeats the underlying dataset `num_times` times.

    Like itertools.repeat, repeating once just returns the full dataset.

    Args:
        num_times (`int`):
            Number of times to repeat the dataset. Values of zero or less
            produce an empty dataset.

    Example:

    ```py
    >>> from datasets import load_dataset
    >>> ds = load_dataset("rotten_tomatoes", split="train")
    >>> ds = ds.take(2).repeat(2)
    >>> list(ds)
    [{'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'},
     {'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'}]
    ```
    """
    # Map-style datasets are materialized, so an unbounded repeat is impossible.
    if num_times is None:
        raise ValueError("Map style datasets do not support indefinite repetition.")
    if num_times <= 0:
        # Zero (or negative) repetitions yield an empty dataset.
        return self.select([])
    copies = [self] * num_times
    return _concatenate_map_style_datasets(copies)
4112+
40814113
def take(self, n: int) -> "Dataset":
40824114
"""
40834115
Create a new [`Dataset`] with only the first `n` elements.

src/datasets/iterable_dataset.py

+91
Original file line numberDiff line numberDiff line change
@@ -1660,6 +1660,54 @@ def num_shards(self) -> int:
16601660
return self.ex_iterable.num_shards
16611661

16621662

1663+
class RepeatExamplesIterable(_BaseExamplesIterable):
    """
    Iterable that repeats the underlying iterable a given number of times.

    Args:
        ex_iterable (`_BaseExamplesIterable`):
            The underlying examples iterable to repeat.
        num_times (`Optional[int]`):
            Number of repetitions. `None` means repeat indefinitely; values <= 0
            yield no examples.
    """

    def __init__(
        self,
        ex_iterable: _BaseExamplesIterable,
        num_times: Optional[int],
    ):
        super().__init__()
        self.ex_iterable = ex_iterable
        self.num_times = num_times

    def _init_state_dict(self) -> dict:
        # Track which repetition we are on, plus the inner iterable's own state,
        # so iteration can resume mid-repeat.
        self._state_dict = {
            "repeat_index": 0,
            "ex_iterable": self.ex_iterable._init_state_dict(),
        }
        return self._state_dict

    def __iter__(self):
        repeat_index = self._state_dict["repeat_index"] if self._state_dict else 0
        while True:
            # num_times is clamped at 0 so negative values behave like 0 (empty).
            if self.num_times is not None and repeat_index >= max(self.num_times, 0):
                break
            yield from self.ex_iterable
            repeat_index += 1
            if self._state_dict:
                self._state_dict["repeat_index"] = repeat_index
                # Reset the inner state so the next pass starts from the beginning.
                self._state_dict["ex_iterable"] = self.ex_iterable._init_state_dict()

    def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable":
        """Shuffle the underlying iterable, then repeat."""
        return RepeatExamplesIterable(self.ex_iterable.shuffle_data_sources(generator), num_times=self.num_times)

    def shard_data_sources(self, worker_id: int, num_workers: int) -> "RepeatExamplesIterable":
        """Shard, then repeat shards."""
        return RepeatExamplesIterable(
            self.ex_iterable.shard_data_sources(worker_id, num_workers),
            num_times=self.num_times,
        )

    @property
    def num_shards(self) -> int:
        # Fix: sibling iterables expose `num_shards` (see the preceding iterable's
        # property in this module) and the wrapped iterable has no `n_shards`
        # attribute, so the original `n_shards` property would have raised
        # AttributeError and never satisfied the base-class interface.
        return self.ex_iterable.num_shards
1709+
1710+
16631711
class TakeExamplesIterable(_BaseExamplesIterable):
16641712
def __init__(
16651713
self,
@@ -2801,6 +2849,49 @@ def skip(self, n: int) -> "IterableDataset":
28012849
token_per_repo_id=self._token_per_repo_id,
28022850
)
28032851

2852+
def repeat(self, num_times: Optional[int]) -> "IterableDataset":
    """
    Create a new [`IterableDataset`] that repeats the underlying dataset `num_times` times.

    N.B. The effect of calling shuffle after repeat depends significantly on buffer size.
    With buffer_size 1, duplicate data is never seen in the same iteration, even after shuffling:
    ds.repeat(n).shuffle(seed=42, buffer_size=1) is equivalent to ds.shuffle(seed=42, buffer_size=1).repeat(n),
    and only shuffles shard orders within each iteration.
    With buffer size >= (num samples in the dataset * num_times), we get full shuffling of the repeated data, i.e. we can observe duplicates in
    the same iteration.

    Args:
        num_times (`int` or `None`):
            Number of times to repeat the dataset. If `None`, the dataset will be repeated indefinitely.

    Example:

    ```py
    >>> from datasets import load_dataset
    >>> ds = load_dataset("rotten_tomatoes", split="train")
    >>> ds = ds.take(2).repeat(2)
    >>> list(ds)
    [{'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'},
     {'label': 1,
     'text': 'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'},
     {'label': 1,
     'text': 'the gorgeously elaborate continuation of " the lord of the rings " trilogy is so huge that a column of words cannot adequately describe co-writer/director peter jackson\'s expanded vision of j . r . r . tolkien\'s middle-earth .'},
     {'label': 1, 'text': 'effective but too-tepid biopic'}]
    ```
    """
    # Wrap the current examples iterable; everything else about the dataset
    # (info, split, formatting, shuffling/distributed state) is carried over.
    repeated_ex_iterable = RepeatExamplesIterable(self._ex_iterable, num_times=num_times)
    return IterableDataset(
        ex_iterable=repeated_ex_iterable,
        info=self._info,
        split=self._split,
        formatting=self._formatting,
        shuffling=copy.deepcopy(self._shuffling),
        distributed=copy.deepcopy(self._distributed),
        token_per_repo_id=self._token_per_repo_id,
    )
2894+
28042895
def take(self, n: int) -> "IterableDataset":
28052896
"""
28062897
Create a new [`IterableDataset`] with only the first `n` elements.

tests/test_arrow_dataset.py

+26
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,32 @@ def test_concatenate_pickle(self, in_memory):
869869
self.assertEqual(dset_concat.info.description, "Dataset2\n\nDataset1")
870870
del dset1, dset2, dset3
871871

872+
def test_repeat(self, in_memory):
    # repeat(3): every column's values appear three times, in order.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
            tripled_dset = dset.repeat(3)
            for col in dset.column_names:
                self.assertListEqual(tripled_dset[col], dset[col] * 3)
            del tripled_dset

    # Indefinite repetition is rejected for map-style datasets.
    with tempfile.TemporaryDirectory() as tmp_dir:
        with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
            with pytest.raises(ValueError):
                dset.repeat(None)

    # Both zero and negative counts yield an empty dataset.
    for num_times in (0, -1):
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self._create_dummy_dataset(in_memory, tmp_dir, multiple_columns=True) as dset:
                emptied_dset = dset.repeat(num_times)
                self.assertEqual(len(emptied_dset), 0)
                del emptied_dset
897+
872898
def test_flatten(self, in_memory):
873899
with tempfile.TemporaryDirectory() as tmp_dir:
874900
with Dataset.from_dict(

tests/test_iterable_dataset.py

+31
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
MappedExamplesIterable,
3434
RandomlyCyclingMultiSourcesExamplesIterable,
3535
RebatchedArrowExamplesIterable,
36+
RepeatExamplesIterable,
3637
SelectColumnsIterable,
3738
ShuffledDataSourcesArrowExamplesIterable,
3839
ShuffledDataSourcesExamplesIterable,
@@ -1167,6 +1168,28 @@ def test_take_examples_iterable():
11671168
assert_load_state_dict_resumes_iteration(take_ex_iterable)
11681169

11691170

1171+
@pytest.mark.parametrize(
    "n, num_times",
    [
        (3, None),
        (3, 3),
        (3, 0),
    ],
)
def test_repeat_examples_iterable(n, num_times):
    base_ex_iterable = ExamplesIterable(generate_examples_fn, {"n": n})
    ex_iterable = RepeatExamplesIterable(base_ex_iterable, num_times=num_times)
    base_examples = [example for _, example in generate_examples_fn(n=n)]
    if num_times is None:
        # Indefinite repetition: spot-check a fixed number of draws and make
        # sure they cycle through the base examples in order.
        max_iters = 135
        iterator = iter(ex_iterable)
        for i in range(max_iters):
            _, example = next(iterator)
            assert example == base_examples[i % len(base_examples)], f"iteration {i} failed,"
    else:
        # Finite repetition: negative counts behave like zero.
        assert [example for _, example in ex_iterable] == base_examples * max(num_times, 0)
1191+
1192+
11701193
def test_vertically_concatenated_examples_iterable():
11711194
ex_iterable1 = ExamplesIterable(generate_examples_fn, {"label": 10})
11721195
ex_iterable2 = ExamplesIterable(generate_examples_fn, {"label": 5})
@@ -1742,6 +1765,14 @@ def test_iterable_dataset_take(dataset: IterableDataset, n):
17421765
assert list(take_dataset) == list(dataset)[:n]
17431766

17441767

1768+
@pytest.mark.parametrize("n", [0, 2])
def test_iterable_dataset_repeat(dataset: IterableDataset, n):
    repeated_dataset = dataset.repeat(n)
    # The repeat op must wrap the underlying iterable and record the count.
    assert isinstance(repeated_dataset._ex_iterable, RepeatExamplesIterable)
    assert repeated_dataset._ex_iterable.num_times == n
    assert list(repeated_dataset) == n * list(dataset)
1774+
1775+
17451776
def test_iterable_dataset_shard():
17461777
num_examples = 20
17471778
num_shards = 5

0 commit comments

Comments
 (0)