[Flux] Enable checkpointing #1195


Merged

merged 6 commits into main from flux-ci-2 on May 15, 2025

Conversation

wwwjn (Contributor) commented May 14, 2025

Context:

  1. Change flux-dev / flux-schnell model training to ~30,000 steps, based on current MAST training results
  2. Enable checkpointing. We enabled `reshard_after_forward` on `final_layer` to solve the issue described in #1167
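For background, FSDP2 sets resharding behavior per module at the point where sharding is applied, so the final-layer change likely resembles the fragment below. This is an illustrative configuration sketch, not the actual parallelize_flux.py code: the module names (`double_blocks`, `final_layer`) and `dp_mesh` are assumptions, and running it requires an initialized distributed process group.

```python
# Illustrative FSDP2 fragment (assumed names, not the real parallelize_flux.py):
# reshard the final layer's parameters after its forward pass so the
# unsharded weights are freed before backward, matching the other blocks.
from torch.distributed.fsdp import fully_shard  # torch >= 2.6; earlier versions expose it elsewhere

for block in model.double_blocks:  # hypothetical transformer blocks
    fully_shard(block, mesh=dp_mesh)
fully_shard(model.final_layer, mesh=dp_mesh, reshard_after_forward=True)
fully_shard(model, mesh=dp_mesh)  # root wrap
```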

Test

If we run the following two runs, the training loss curves should be identical with `deterministic = True`:

  1. Without checkpoint save and load, total steps = 10
  2. Save a checkpoint at step 5, then load it at step 5 and continue training

Currently, issue #1194 makes the training loss not strictly identical. To exclude the influence of #1194, we reset the seeds (by calling `set_deterministic()`) at the beginning of step 6. With that, checkpoint save/load produces an identical training loss.
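The two-run check can be sketched in a single process with plain PyTorch (no FSDP or torchtitan). Here `set_deterministic` is a hypothetical stand-in that only reseeds torch's RNG; both runs reseed at step 6, which is what excludes the un-checkpointed RNG state (#1194) from the comparison:

```python
import copy
import torch

def set_deterministic(seed: int) -> None:
    # Stand-in for torchtitan's set_deterministic(): reset RNG state
    # so both runs draw identical data/noise from this point on.
    torch.manual_seed(seed)

def new_trainer():
    # Seed before construction so both runs get identical initial weights.
    set_deterministic(0)
    model = torch.nn.Linear(4, 1)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    return model, opt

def run_steps(model, opt, start: int, end: int) -> list[float]:
    losses = []
    for step in range(start, end + 1):
        if step == 6:
            # Reseed at the beginning of step 6 in *both* runs, so RNG
            # state lost across checkpoint load cannot cause divergence.
            set_deterministic(1)
        x = torch.randn(8, 4)
        loss = (model(x) - 1.0).pow(2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses.append(loss.item())
    return losses

# Run 1: 10 steps straight through, no checkpointing.
model, opt = new_trainer()
run1 = run_steps(model, opt, 1, 10)

# Run 2: 5 steps, "save" a checkpoint, restart, "load", continue from step 6.
model, opt = new_trainer()
run2 = run_steps(model, opt, 1, 5)
ckpt = {"model": copy.deepcopy(model.state_dict()),
        "opt": copy.deepcopy(opt.state_dict())}

model, opt = new_trainer()              # simulate a fresh process
model.load_state_dict(ckpt["model"])
opt.load_state_dict(ckpt["opt"])
run2 += run_steps(model, opt, 6, 10)

assert run1 == run2  # loss curves match step for step
```

Since all ops run deterministically on CPU with identical inputs, the per-step losses match exactly, mirroring the save/load-at-step-5 check described above.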

![Screenshot 2025-05-14 at 2 06 23 PM](https://github.com/user-attachments/assets/22882b71-378c-44fa-bd48-8a8f238aa1b0)

@wwwjn wwwjn requested a review from tianyu-l May 14, 2025 21:06
@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label May 14, 2025
@wwwjn wwwjn requested a review from fegin May 14, 2025 21:06
fegin (Contributor) left a comment:

I suggest that you split the "fix" of parallelize_flux.py into another PR.

tianyu-l (Contributor) left a comment:

LGTM. Please address comments before merge.

wwwjn (Contributor, Author) commented May 15, 2025

> I suggest that you split the "fix" of parallelize_flux.py into another PR.

Thank you @fegin for the reminder; I created a separate PR for this. I will split later changes into separate PRs as well.

@wwwjn wwwjn merged commit 0104e39 into main May 15, 2025
6 checks passed
@tianyu-l tianyu-l deleted the flux-ci-2 branch May 15, 2025 23:42
wwwjn added a commit that referenced this pull request May 16, 2025