@@ -212,7 +212,7 @@ def __init__(self, generate_examples_fn: Callable[..., tuple[Key, dict]], kwargs
212
212
self .kwargs = kwargs
213
213
214
214
def _init_state_dict (self ) -> dict :
215
- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
215
+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
216
216
return self ._state_dict
217
217
218
218
def __iter__ (self ):
@@ -250,7 +250,7 @@ def __init__(
250
250
self .generator = deepcopy (generator )
251
251
252
252
def _init_state_dict (self ) -> dict :
253
- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
253
+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
254
254
return self ._state_dict
255
255
256
256
def __iter__ (self ):
@@ -290,7 +290,7 @@ def iter_arrow(self):
290
290
return self ._iter_arrow
291
291
292
292
def _init_state_dict (self ) -> dict :
293
- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
293
+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
294
294
return self ._state_dict
295
295
296
296
def __iter__ (self ):
@@ -357,7 +357,7 @@ def __init__(
357
357
self .generator = deepcopy (generator )
358
358
359
359
def _init_state_dict (self ) -> dict :
360
- self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 }
360
+ self ._state_dict = {"shard_idx" : 0 , "shard_example_idx" : 0 , "type" : self . __class__ . __name__ }
361
361
return self ._state_dict
362
362
363
363
def __iter__ (self ):
@@ -437,11 +437,12 @@ def features(self):
437
437
438
438
def _init_state_dict (self ) -> dict :
439
439
self ._state_dict = {
440
- "ex_iterable " : self .ex_iterable ._init_state_dict (),
440
+ "examples_iterable " : self .ex_iterable ._init_state_dict (),
441
441
"previous_state" : None ,
442
442
"batch_idx" : 0 ,
443
443
"num_chunks_since_previous_state" : 0 ,
444
444
"cropped_chunk_length" : 0 ,
445
+ "type" : self .__class__ .__name__ ,
445
446
}
446
447
return self ._state_dict
447
448
@@ -680,6 +681,7 @@ def _init_state_dict(self) -> dict:
680
681
"ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
681
682
"previous_states" : [None ] * len (self .ex_iterables ),
682
683
"is_exhausted" : [False ] * len (self .ex_iterables ),
684
+ "type" : self .__class__ .__name__ ,
683
685
}
684
686
return self ._state_dict
685
687
@@ -778,6 +780,7 @@ def _init_state_dict(self) -> dict:
778
780
self ._state_dict = {
779
781
"ex_iterable_idx" : 0 ,
780
782
"ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
783
+ "type" : self .__class__ .__name__ ,
781
784
}
782
785
return self ._state_dict
783
786
@@ -858,7 +861,10 @@ def features(self):
858
861
return self .ex_iterables [0 ].features
859
862
860
863
def _init_state_dict (self ) -> dict :
861
- self ._state_dict = {"ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ]}
864
+ self ._state_dict = {
865
+ "ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
866
+ "type" : self .__class__ .__name__ ,
867
+ }
862
868
return self ._state_dict
863
869
864
870
def __iter__ (self ):
@@ -960,6 +966,7 @@ def _init_state_dict(self) -> dict:
960
966
"ex_iterables" : [ex_iterable ._init_state_dict () for ex_iterable in self .ex_iterables ],
961
967
"previous_states" : [None ] * len (self .ex_iterables ),
962
968
"is_exhausted" : [False ] * len (self .ex_iterables ),
969
+ "type" : self .__class__ .__name__ ,
963
970
}
964
971
return self ._state_dict
965
972
@@ -1060,10 +1067,11 @@ def features(self):
1060
1067
1061
1068
def _init_state_dict (self ) -> dict :
1062
1069
self ._state_dict = {
1063
- "ex_iterable " : self .ex_iterable ._init_state_dict (),
1070
+ "examples_iterable " : self .ex_iterable ._init_state_dict (),
1064
1071
"previous_state" : None ,
1065
1072
"num_examples_since_previous_state" : 0 ,
1066
1073
"previous_state_example_idx" : 0 ,
1074
+ "type" : self .__class__ .__name__ ,
1067
1075
}
1068
1076
return self ._state_dict
1069
1077
@@ -1578,7 +1586,11 @@ def features(self):
1578
1586
return self .ex_iterable .features
1579
1587
1580
1588
def _init_state_dict (self ) -> dict :
1581
- self ._state_dict = {"skipped" : False , "ex_iterable" : self .ex_iterable ._init_state_dict ()}
1589
+ self ._state_dict = {
1590
+ "skipped" : False ,
1591
+ "examples_iterable" : self .ex_iterable ._init_state_dict (),
1592
+ "type" : self .__class__ .__name__ ,
1593
+ }
1582
1594
return self ._state_dict
1583
1595
1584
1596
def __iter__ (self ):
@@ -1642,7 +1654,8 @@ def __init__(
1642
1654
def _init_state_dict (self ) -> dict :
1643
1655
self ._state_dict = {
1644
1656
"repeat_index" : 0 ,
1645
- "ex_iterable" : self .ex_iterable ._init_state_dict (),
1657
+ "examples_iterable" : self .ex_iterable ._init_state_dict (),
1658
+ "type" : self .__class__ .__name__ ,
1646
1659
}
1647
1660
return self ._state_dict
1648
1661
@@ -1655,7 +1668,7 @@ def __iter__(self):
1655
1668
repeat_index += 1
1656
1669
if self ._state_dict :
1657
1670
self ._state_dict ["repeat_index" ] = repeat_index
1658
- self ._state_dict ["ex_iterable " ] = self .ex_iterable ._init_state_dict ()
1671
+ self ._state_dict ["examples_iterable " ] = self .ex_iterable ._init_state_dict ()
1659
1672
1660
1673
def shuffle_data_sources (self , generator : np .random .Generator ) -> "RepeatExamplesIterable" :
1661
1674
"""Shuffle the underlying iterable, then repeat."""
@@ -1697,7 +1710,11 @@ def features(self):
1697
1710
return self .ex_iterable .features
1698
1711
1699
1712
def _init_state_dict (self ) -> dict :
1700
- self ._state_dict = {"num_taken" : 0 , "ex_iterable" : self .ex_iterable ._init_state_dict ()}
1713
+ self ._state_dict = {
1714
+ "num_taken" : 0 ,
1715
+ "examples_iterable" : self .ex_iterable ._init_state_dict (),
1716
+ "type" : self .__class__ .__name__ ,
1717
+ }
1701
1718
return self ._state_dict
1702
1719
1703
1720
def __iter__ (self ):
@@ -1956,9 +1973,8 @@ def __init__(
1956
1973
self ._token_per_repo_id : dict [str , Union [str , bool , None ]] = token_per_repo_id or {}
1957
1974
self ._epoch : Union [int , "torch.Tensor" ] = _maybe_share_with_torch_persistent_workers (0 )
1958
1975
self ._starting_state_dict : Optional [dict ] = None
1959
- self ._prepared_ex_iterable = self ._prepare_ex_iterable_for_iteration ()
1960
- self ._state_dict = self ._prepared_ex_iterable ._init_state_dict ()
1961
- _maybe_add_torch_iterable_dataset_parent_class (self .__class__ )
1976
+ self ._prepare_ex_iterable_for_iteration () # set state_dict
1977
+ _maybe_add_torch_iterable_dataset_parent_class (self .__class__ ) # subclass of torch IterableDataset
1962
1978
1963
1979
def state_dict (self ) -> dict :
1964
1980
"""Get the current state_dict of the dataset.
@@ -2061,7 +2077,6 @@ def load_state_dict(self, state_dict: dict) -> None:
2061
2077
>>> dataloader.load_state_dict(state_dict) # uses ds.load_state_dict() under the hood
2062
2078
```
2063
2079
"""
2064
- self ._prepared_ex_iterable .load_state_dict (state_dict )
2065
2080
self ._starting_state_dict = state_dict
2066
2081
2067
2082
def __repr__ (self ):
@@ -2136,9 +2151,12 @@ def _iter_pytorch(self):
2136
2151
ex_iterable = ex_iterable .shard_data_sources (
2137
2152
num_shards = worker_info .num_workers , index = worker_info .id , contiguous = False
2138
2153
)
2139
- self ._state_dict = ex_iterable ._init_state_dict ()
2140
- if self ._starting_state_dict :
2141
- ex_iterable .load_state_dict (self ._starting_state_dict )
2154
+ self ._state_dict = {
2155
+ "examples_iterable" : ex_iterable ._init_state_dict (),
2156
+ "epoch" : self .epoch ,
2157
+ }
2158
+ if self ._starting_state_dict and self .epoch == self ._starting_state_dict ["epoch" ]:
2159
+ ex_iterable .load_state_dict (self ._starting_state_dict ["examples_iterable" ])
2142
2160
2143
2161
if self ._formatting and (ex_iterable .iter_arrow or self ._formatting .is_table ):
2144
2162
formatter = get_formatter (self ._formatting .format_type , features = self .features )
@@ -2216,9 +2234,12 @@ def _prepare_ex_iterable_for_iteration(
2216
2234
token_per_repo_id = self ._token_per_repo_id ,
2217
2235
)
2218
2236
2219
- self ._state_dict = ex_iterable ._init_state_dict ()
2220
- if self ._starting_state_dict :
2221
- ex_iterable .load_state_dict (self ._starting_state_dict )
2237
+ self ._state_dict = {
2238
+ "examples_iterable" : ex_iterable ._init_state_dict (),
2239
+ "epoch" : self .epoch ,
2240
+ }
2241
+ if self ._starting_state_dict and self .epoch == self ._starting_state_dict ["epoch" ]:
2242
+ ex_iterable .load_state_dict (self ._starting_state_dict ["examples_iterable" ])
2222
2243
return ex_iterable
2223
2244
2224
2245
def __iter__ (self ):
0 commit comments