Skip to content

Commit e8ee24a

Browse files
authored
Fix resuming after ds.set_epoch(new_epoch) (#7451)
* fix resuming with new epoch
* more readable states
* add test
* make style
1 parent 7ad7379 commit e8ee24a

File tree

5 files changed: +65 −33 lines changed

src/datasets/commands/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def run(self):
167167
output_file = os.path.join(output_dir, f_name)
168168
os.makedirs(output_dir, exist_ok=True)
169169
self._logger.info(f"Adding directory {output_dir}")
170-
imports_to_builder_map.update({imp: output_dir for imp in tfds_imports})
170+
imports_to_builder_map.update(dict.fromkeys(tfds_imports, output_dir))
171171
else:
172172
# Utilities will be moved at the end
173173
utils_files.append(output_file)

src/datasets/dataset_dict.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -931,7 +931,7 @@ def map(
931931
"""
932932
self._check_values_type()
933933
if cache_file_names is None:
934-
cache_file_names = {k: None for k in self}
934+
cache_file_names = dict.fromkeys(self)
935935

936936
dataset_dict = {}
937937
for split, dataset in self.items():
@@ -1051,7 +1051,7 @@ def filter(
10511051
"""
10521052
self._check_values_type()
10531053
if cache_file_names is None:
1054-
cache_file_names = {k: None for k in self}
1054+
cache_file_names = dict.fromkeys(self)
10551055
return DatasetDict(
10561056
{
10571057
k: dataset.filter(
@@ -1109,7 +1109,7 @@ def flatten_indices(
11091109
"""
11101110
self._check_values_type()
11111111
if cache_file_names is None:
1112-
cache_file_names = {k: None for k in self}
1112+
cache_file_names = dict.fromkeys(self)
11131113
return DatasetDict(
11141114
{
11151115
k: dataset.flatten_indices(
@@ -1176,7 +1176,7 @@ def sort(
11761176
"""
11771177
self._check_values_type()
11781178
if indices_cache_file_names is None:
1179-
indices_cache_file_names = {k: None for k in self}
1179+
indices_cache_file_names = dict.fromkeys(self)
11801180
return DatasetDict(
11811181
{
11821182
k: dataset.sort(
@@ -1254,13 +1254,13 @@ def shuffle(
12541254
raise ValueError("Please specify seed or seeds, but not both")
12551255
seeds = seed if seed is not None else seeds
12561256
if seeds is None:
1257-
seeds = {k: None for k in self}
1257+
seeds = dict.fromkeys(self)
12581258
elif not isinstance(seeds, dict):
1259-
seeds = {k: seeds for k in self}
1259+
seeds = dict.fromkeys(self, seeds)
12601260
if generators is None:
1261-
generators = {k: None for k in self}
1261+
generators = dict.fromkeys(self)
12621262
if indices_cache_file_names is None:
1263-
indices_cache_file_names = {k: None for k in self}
1263+
indices_cache_file_names = dict.fromkeys(self)
12641264
return DatasetDict(
12651265
{
12661266
k: dataset.shuffle(
@@ -1326,7 +1326,7 @@ def save_to_disk(
13261326
fs, _ = url_to_fs(dataset_dict_path, **(storage_options or {}))
13271327

13281328
if num_shards is None:
1329-
num_shards = {k: None for k in self}
1329+
num_shards = dict.fromkeys(self)
13301330
elif not isinstance(num_shards, dict):
13311331
raise ValueError(
13321332
"Please provide one `num_shards` per dataset in the dataset dictionary, e.g. {{'train': 128, 'test': 4}}"
@@ -1696,7 +1696,7 @@ def push_to_hub(
16961696
```
16971697
"""
16981698
if num_shards is None:
1699-
num_shards = {k: None for k in self}
1699+
num_shards = dict.fromkeys(self)
17001700
elif not isinstance(num_shards, dict):
17011701
raise ValueError(
17021702
"Please provide one `num_shards` per dataset in the dataset dictionary, e.g. {{'train': 128, 'test': 4}}"

src/datasets/formatting/formatting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ def __init__(self, pa_table: pa.Table, formatter: "Formatter"):
270270
self.pa_table = pa_table
271271
self.formatter = formatter
272272

273-
self.data = {key: None for key in pa_table.column_names}
273+
self.data = dict.fromkeys(pa_table.column_names)
274274
self.keys_to_format = set(self.data.keys())
275275

276276
def __len__(self):

src/datasets/iterable_dataset.py

Lines changed: 42 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def __init__(self, generate_examples_fn: Callable[..., tuple[Key, dict]], kwargs
212212
self.kwargs = kwargs
213213

214214
def _init_state_dict(self) -> dict:
215-
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0}
215+
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
216216
return self._state_dict
217217

218218
def __iter__(self):
@@ -250,7 +250,7 @@ def __init__(
250250
self.generator = deepcopy(generator)
251251

252252
def _init_state_dict(self) -> dict:
253-
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0}
253+
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
254254
return self._state_dict
255255

256256
def __iter__(self):
@@ -290,7 +290,7 @@ def iter_arrow(self):
290290
return self._iter_arrow
291291

292292
def _init_state_dict(self) -> dict:
293-
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0}
293+
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
294294
return self._state_dict
295295

296296
def __iter__(self):
@@ -357,7 +357,7 @@ def __init__(
357357
self.generator = deepcopy(generator)
358358

359359
def _init_state_dict(self) -> dict:
360-
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0}
360+
self._state_dict = {"shard_idx": 0, "shard_example_idx": 0, "type": self.__class__.__name__}
361361
return self._state_dict
362362

363363
def __iter__(self):
@@ -437,11 +437,12 @@ def features(self):
437437

438438
def _init_state_dict(self) -> dict:
439439
self._state_dict = {
440-
"ex_iterable": self.ex_iterable._init_state_dict(),
440+
"examples_iterable": self.ex_iterable._init_state_dict(),
441441
"previous_state": None,
442442
"batch_idx": 0,
443443
"num_chunks_since_previous_state": 0,
444444
"cropped_chunk_length": 0,
445+
"type": self.__class__.__name__,
445446
}
446447
return self._state_dict
447448

@@ -680,6 +681,7 @@ def _init_state_dict(self) -> dict:
680681
"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables],
681682
"previous_states": [None] * len(self.ex_iterables),
682683
"is_exhausted": [False] * len(self.ex_iterables),
684+
"type": self.__class__.__name__,
683685
}
684686
return self._state_dict
685687

@@ -778,6 +780,7 @@ def _init_state_dict(self) -> dict:
778780
self._state_dict = {
779781
"ex_iterable_idx": 0,
780782
"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables],
783+
"type": self.__class__.__name__,
781784
}
782785
return self._state_dict
783786

@@ -858,7 +861,10 @@ def features(self):
858861
return self.ex_iterables[0].features
859862

860863
def _init_state_dict(self) -> dict:
861-
self._state_dict = {"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables]}
864+
self._state_dict = {
865+
"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables],
866+
"type": self.__class__.__name__,
867+
}
862868
return self._state_dict
863869

864870
def __iter__(self):
@@ -960,6 +966,7 @@ def _init_state_dict(self) -> dict:
960966
"ex_iterables": [ex_iterable._init_state_dict() for ex_iterable in self.ex_iterables],
961967
"previous_states": [None] * len(self.ex_iterables),
962968
"is_exhausted": [False] * len(self.ex_iterables),
969+
"type": self.__class__.__name__,
963970
}
964971
return self._state_dict
965972

@@ -1060,10 +1067,11 @@ def features(self):
10601067

10611068
def _init_state_dict(self) -> dict:
10621069
self._state_dict = {
1063-
"ex_iterable": self.ex_iterable._init_state_dict(),
1070+
"examples_iterable": self.ex_iterable._init_state_dict(),
10641071
"previous_state": None,
10651072
"num_examples_since_previous_state": 0,
10661073
"previous_state_example_idx": 0,
1074+
"type": self.__class__.__name__,
10671075
}
10681076
return self._state_dict
10691077

@@ -1578,7 +1586,11 @@ def features(self):
15781586
return self.ex_iterable.features
15791587

15801588
def _init_state_dict(self) -> dict:
1581-
self._state_dict = {"skipped": False, "ex_iterable": self.ex_iterable._init_state_dict()}
1589+
self._state_dict = {
1590+
"skipped": False,
1591+
"examples_iterable": self.ex_iterable._init_state_dict(),
1592+
"type": self.__class__.__name__,
1593+
}
15821594
return self._state_dict
15831595

15841596
def __iter__(self):
@@ -1642,7 +1654,8 @@ def __init__(
16421654
def _init_state_dict(self) -> dict:
16431655
self._state_dict = {
16441656
"repeat_index": 0,
1645-
"ex_iterable": self.ex_iterable._init_state_dict(),
1657+
"examples_iterable": self.ex_iterable._init_state_dict(),
1658+
"type": self.__class__.__name__,
16461659
}
16471660
return self._state_dict
16481661

@@ -1655,7 +1668,7 @@ def __iter__(self):
16551668
repeat_index += 1
16561669
if self._state_dict:
16571670
self._state_dict["repeat_index"] = repeat_index
1658-
self._state_dict["ex_iterable"] = self.ex_iterable._init_state_dict()
1671+
self._state_dict["examples_iterable"] = self.ex_iterable._init_state_dict()
16591672

16601673
def shuffle_data_sources(self, generator: np.random.Generator) -> "RepeatExamplesIterable":
16611674
"""Shuffle the underlying iterable, then repeat."""
@@ -1697,7 +1710,11 @@ def features(self):
16971710
return self.ex_iterable.features
16981711

16991712
def _init_state_dict(self) -> dict:
1700-
self._state_dict = {"num_taken": 0, "ex_iterable": self.ex_iterable._init_state_dict()}
1713+
self._state_dict = {
1714+
"num_taken": 0,
1715+
"examples_iterable": self.ex_iterable._init_state_dict(),
1716+
"type": self.__class__.__name__,
1717+
}
17011718
return self._state_dict
17021719

17031720
def __iter__(self):
@@ -1956,9 +1973,8 @@ def __init__(
19561973
self._token_per_repo_id: dict[str, Union[str, bool, None]] = token_per_repo_id or {}
19571974
self._epoch: Union[int, "torch.Tensor"] = _maybe_share_with_torch_persistent_workers(0)
19581975
self._starting_state_dict: Optional[dict] = None
1959-
self._prepared_ex_iterable = self._prepare_ex_iterable_for_iteration()
1960-
self._state_dict = self._prepared_ex_iterable._init_state_dict()
1961-
_maybe_add_torch_iterable_dataset_parent_class(self.__class__)
1976+
self._prepare_ex_iterable_for_iteration() # set state_dict
1977+
_maybe_add_torch_iterable_dataset_parent_class(self.__class__) # subclass of torch IterableDataset
19621978

19631979
def state_dict(self) -> dict:
19641980
"""Get the current state_dict of the dataset.
@@ -2061,7 +2077,6 @@ def load_state_dict(self, state_dict: dict) -> None:
20612077
>>> dataloader.load_state_dict(state_dict) # uses ds.load_state_dict() under the hood
20622078
```
20632079
"""
2064-
self._prepared_ex_iterable.load_state_dict(state_dict)
20652080
self._starting_state_dict = state_dict
20662081

20672082
def __repr__(self):
@@ -2136,9 +2151,12 @@ def _iter_pytorch(self):
21362151
ex_iterable = ex_iterable.shard_data_sources(
21372152
num_shards=worker_info.num_workers, index=worker_info.id, contiguous=False
21382153
)
2139-
self._state_dict = ex_iterable._init_state_dict()
2140-
if self._starting_state_dict:
2141-
ex_iterable.load_state_dict(self._starting_state_dict)
2154+
self._state_dict = {
2155+
"examples_iterable": ex_iterable._init_state_dict(),
2156+
"epoch": self.epoch,
2157+
}
2158+
if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]:
2159+
ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"])
21422160

21432161
if self._formatting and (ex_iterable.iter_arrow or self._formatting.is_table):
21442162
formatter = get_formatter(self._formatting.format_type, features=self.features)
@@ -2216,9 +2234,12 @@ def _prepare_ex_iterable_for_iteration(
22162234
token_per_repo_id=self._token_per_repo_id,
22172235
)
22182236

2219-
self._state_dict = ex_iterable._init_state_dict()
2220-
if self._starting_state_dict:
2221-
ex_iterable.load_state_dict(self._starting_state_dict)
2237+
self._state_dict = {
2238+
"examples_iterable": ex_iterable._init_state_dict(),
2239+
"epoch": self.epoch,
2240+
}
2241+
if self._starting_state_dict and self.epoch == self._starting_state_dict["epoch"]:
2242+
ex_iterable.load_state_dict(self._starting_state_dict["examples_iterable"])
22222243
return ex_iterable
22232244

22242245
def __iter__(self):

tests/test_iterable_dataset.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,17 @@ def test_iterable_dataset_set_epoch(dataset: IterableDataset):
15811581
assert dataset._epoch == 42
15821582

15831583

1584+
def test_iterable_dataset_set_epoch_resuming(dataset: IterableDataset):
1585+
dataset_length = len(list(dataset))
1586+
assert len(list(dataset)) == dataset_length > 0
1587+
dataset.load_state_dict(dataset.state_dict())
1588+
assert len(list(dataset)) == 0
1589+
dataset.set_epoch(1)
1590+
assert len(list(dataset)) == dataset_length > 0
1591+
dataset.load_state_dict(dataset.state_dict())
1592+
assert len(list(dataset)) == 0
1593+
1594+
15841595
@pytest.mark.parametrize("seed", [None, 42, 1337])
15851596
@pytest.mark.parametrize("epoch", [None, 0, 1, 10])
15861597
def test_iterable_dataset_set_epoch_of_shuffled_dataset(dataset: IterableDataset, seed, epoch):

0 commit comments

Comments
 (0)