Minimal reproduction of the error when saving state dict after looping #1358


Open
wants to merge 5 commits into base: main
Conversation

wesleytruong
Contributor

When looping over an iterable dataset, the state dict is not saved correctly, as shown by:

save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 8, 'type': 'ArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'batch_idx': 1, 'num_chunks_since_previous_state': 1, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 1}
load state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'previous_state': None, 'batch_idx': 0, 'num_chunks_since_previous_state': 0, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 0}

The error can be reproduced when switching between batch sizes less than and greater than 32.
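To make the expectation concrete, here is a toy illustration (pure Python, not the torchdata implementation) of the round-trip invariant the report relies on: loading a saved state into a fresh loader should reproduce that state exactly, even after the underlying iterable has looped past its end.

```python
# Toy stand-in for a stateful dataloader: cycles over `data`,
# tracking position and epoch in its state dict.
class ToyStatefulLoader:
    def __init__(self, data):
        self.data = list(data)
        self.idx = 0
        self.epoch = 0

    def __next__(self):
        item = self.data[self.idx]
        self.idx += 1
        if self.idx == len(self.data):  # loop back and bump the epoch
            self.idx = 0
            self.epoch += 1
        return item

    def state_dict(self):
        return {"idx": self.idx, "epoch": self.epoch}

    def load_state_dict(self, state):
        self.idx = state["idx"]
        self.epoch = state["epoch"]


loader = ToyStatefulLoader([1, 2, 3])
for _ in range(4):           # 4 steps on a 3-item dataset: it has looped
    next(loader)
saved = loader.state_dict()  # {'idx': 1, 'epoch': 1}

fresh = ToyStatefulLoader([1, 2, 3])
fresh.load_state_dict(saved)
assert fresh.state_dict() == saved  # the symmetry the bug report expects
assert next(fresh) == next(loader)  # and identical continuation
```

In the traces above, this symmetry is broken: the loaded state comes back with `epoch: 0` and a reset `shard_example_idx`, as if the loader were fresh.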

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 30, 2025
@tianyu-l
Contributor

Hi @divyanshk
We are observing inconsistent save & load results using StatefulDataLoader:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/components/dataloader.py#L98

It only happens when we save after the dataset has looped.
I'm not sure whether we're doing something wrong on our side; could you help us take a look?

@divyanshk

@wesleytruong Can you confirm whether the "load state dict:" output in the description above comes from a fresh iterator?

@wesleytruong
Contributor Author

wesleytruong commented Jul 2, 2025

@wesleytruong Can you confirm if the "load state dict: " in the description above is that of a fresh iterator ?

Yes, the Hugging Face IterableDataset's state dict in that print statement is produced when load_state_dict(state) is called on a freshly initialized dataloader.

@divyanshk

@wesleytruong Got it! Getting a fresh iterator after the saved iterator has completed is expected.

Can you turn on persistent_workers in the dataloader:

super().__init__(dataset, batch_size, collate_fn=collate_fn)

and see if this solves your use case?
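For reference, the suggested change might look like this (a sketch against the constructor call above; note that persistent_workers only takes effect when num_workers > 0):

```python
super().__init__(
    dataset,
    batch_size,
    collate_fn=collate_fn,
    num_workers=1,            # persistent_workers requires num_workers > 0
    persistent_workers=True,  # keep the worker, and its iterator, alive across epochs
)
```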

@wesleytruong
Contributor Author

@wesleytruong Got it! Getting a fresh iterator after the saved iterator has completed is expected.

Can you turn on persistent_workers in the dataloader

super().__init__(dataset, batch_size, collate_fn=collate_fn)

and see if this solves for your use-case ?

No, turning persistent_workers on with num_workers=1 doesn't seem to fix the issue. The state dict trace becomes:

save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'previous_state': None, 'batch_idx': 0, 'num_chunks_since_previous_state': 0, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 0}
save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 8, 'type': 'ArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'batch_idx': 1, 'num_chunks_since_previous_state': 1, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 1}
save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 8, 'type': 'ArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'batch_idx': 2, 'num_chunks_since_previous_state': 2, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 2}
save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 8, 'type': 'ArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'batch_idx': 3, 'num_chunks_since_previous_state': 3, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 3}
load state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'previous_state': None, 'batch_idx': 0, 'num_chunks_since_previous_state': 0, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 0}
save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'previous_state': None, 'batch_idx': 0, 'num_chunks_since_previous_state': 0, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 0}
save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 8, 'type': 'ArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'batch_idx': 4, 'num_chunks_since_previous_state': 4, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 4}
save state dict:  {'examples_iterable': {'examples_iterable': {'shard_idx': 0, 'shard_example_idx': 8, 'type': 'ArrowExamplesIterable'}, 'previous_state': {'shard_idx': 0, 'shard_example_idx': 0, 'type': 'ArrowExamplesIterable'}, 'batch_idx': 2, 'num_chunks_since_previous_state': 2, 'cropped_chunk_length': 0, 'type': 'RebatchedArrowExamplesIterable'}, 'epoch': 1}

@divyanshk

@wesleytruong I'm not able to follow what's happening from the print statements.

So that I can better understand: what is the expected behavior here? Do we want the states to be the same? If so, do we have a way to restart from the top on the next next(..) call? Right now the dataloader gets a fresh iterator if the loaded iterator has finished; that is how we keep yielding items from next(..) after loading. Without that, the iterator would hit StopIteration or some other undefined behavior after loading.

@wesleytruong
Contributor Author

@divyanshk Sorry for the confusing print statements; you can ignore them if they're not relevant. I included them because I didn't understand why there seem to be more calls to state_dict and load_state_dict after turning on persistent workers, and wondered if that meant something to you.

The expected behavior is that state_dict and load_state_dict should be symmetric: the state dict is copied from one dataloader and used to load another dataloader from that checkpoint.
You're right that when the dataset finishes, the dataloader's __iter__ retrieves a new IterableDataset iterator so it can continue yielding data from the beginning of the dataset, and the way we renew this iterator currently seems to work for both map-style and iterable-style datasets.

The problem I'm encountering is that state_dict and load_state_dict are not symmetric. In the test, stage 1 has a single dataloader iterate over the dataset; its state dict is then retrieved and used to load a second dataloader from that state. In stage 2, the two dataloaders retrieve data in parallel, and we expect their data to match. The issue is that when stage 1 runs for more than one epoch, the second dataloader starts loading data from the beginning, because the state dict it loads from is empty. We believe the loss of the state dict's information could be due to the parent call to StatefulDataLoader in ParallelDataLoader's state_dict function.
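The two-stage test described above can be sketched with a toy cycling loader standing in for the real ParallelDataLoader (pure Python; all names here are illustrative, not the actual torchtitan test code):

```python
# Sketch of the two-stage protocol: stage 1 iterates past one epoch and
# checkpoints; stage 2 restores a fresh loader from the checkpoint and
# checks that both loaders now yield identical items in lockstep.
class CyclingLoader:
    def __init__(self, data):
        self.data, self.idx, self.epoch = list(data), 0, 0

    def __next__(self):
        item = self.data[self.idx]
        self.idx += 1
        if self.idx == len(self.data):  # wrap around, start a new epoch
            self.idx, self.epoch = 0, self.epoch + 1
        return item

    def state_dict(self):
        return {"idx": self.idx, "epoch": self.epoch}

    def load_state_dict(self, state):
        self.idx, self.epoch = state["idx"], state["epoch"]


data = list(range(5))

# Stage 1: iterate past one epoch, then checkpoint.
a = CyclingLoader(data)
for _ in range(7):                 # 7 > 5, so the dataset has looped
    next(a)
checkpoint = a.state_dict()

# Stage 2: restore into a fresh loader; both must now run in parallel.
b = CyclingLoader(data)
b.load_state_dict(checkpoint)
stage2 = [(next(a), next(b)) for _ in range(6)]
assert all(x == y for x, y in stage2)  # parallel data, as the test expects
```

In the reported bug, the equivalent of `checkpoint` arrives empty once stage 1 exceeds one epoch, so the second loader restarts from the beginning instead of running in lockstep.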

Successfully merging this pull request may close these issues.

Data loader's state_dict is being lost between being saved and loaded when the dataset loops
4 participants