# [Flux] Enable checkpointing #1195
## Conversation
I suggest that you split the "fix" of parallelize_flux.py into another PR.
LGTM. Please address comments before merge.
Thank you @fegin for the reminder. I created a separate PR for this, and I will make separate PRs for later changes as well.
## Context

1. Change flux-dev / flux-schnell model training to ~30,000 steps, based on current MAST training results.
2. Enable checkpointing. We enabled `reshard_after_forward` for `final_layer` to solve the issue described [here](#1167 (comment)).

## Test

If we run the following 2 runs, the training loss curves should be identical with `deterministic = True`:

1. Without checkpoint save and load, total steps = 10.
2. Save a checkpoint at step 5, load the checkpoint at step 5, and continue training.

Currently, issue #1194 makes the training loss not strictly identical. To exclude the influence of #1194, we reset the seeds (by calling `set_deterministic()`) at the beginning of step 6. With that, the checkpoint save/load makes the training loss identical.

<img width="1675" alt="Screenshot 2025-05-14 at 2 06 23 PM" src="https://github.com/user-attachments/assets/22882b71-378c-44fa-bd48-8a8f238aa1b0" />
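As a hedged illustration of the check described above (not the actual torchtitan test code), the sketch below reproduces the idea with a toy model: one run trains 10 steps straight through, the other saves model/optimizer state at step 5 and resumes from it, and both call a stand-in `set_deterministic()` at the beginning of step 6 so the RNG state matches despite issue #1194. The model, optimizer, and seeding helper here are all illustrative assumptions.

```python
import torch
from torch import nn


def set_deterministic(seed: int = 0) -> None:
    """Hypothetical stand-in for the set_deterministic() mentioned above."""
    torch.manual_seed(seed)
    torch.use_deterministic_algorithms(True)


def train(model, opt, first_step, last_step, reseed_at=None):
    """Run steps first_step..last_step (1-indexed) and return the per-step losses."""
    losses = []
    for step in range(first_step, last_step + 1):
        if step == reseed_at:
            set_deterministic()  # reset seeds to exclude the influence of issue #1194
        x = torch.randn(8, 16)   # stand-in for a real data batch
        loss = model(x).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses


def make_model_and_opt():
    model = nn.Linear(16, 16)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    return model, opt


# Run 1: 10 steps without any checkpoint save/load.
set_deterministic()
model_a, opt_a = make_model_and_opt()
losses_a = train(model_a, opt_a, 1, 10, reseed_at=6)

# Run 2: 5 steps, save a "checkpoint", then load it and train 5 more steps.
set_deterministic()
model_b, opt_b = make_model_and_opt()
losses_b = train(model_b, opt_b, 1, 5)
ckpt = {"model": model_b.state_dict(), "opt": opt_b.state_dict()}  # save at step 5

model_c, opt_c = make_model_and_opt()
model_c.load_state_dict(ckpt["model"])  # load the step-5 checkpoint
opt_c.load_state_dict(ckpt["opt"])
losses_b += train(model_c, opt_c, 6, 10, reseed_at=6)

# With deterministic seeding and the seed reset at step 6, the two loss curves match.
assert losses_a == losses_b
```

The seed reset at step 6 in both runs mirrors the workaround described above: because the RNG state itself is not restored from the checkpoint (issue #1194), reseeding at the resume point is what makes the two loss curves comparable step by step.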