
Commit 9e1c56e

[Flux] Enable checkpointing (#1195)
## Context

1. Change flux-dev / flux-schnell model training to ~30,000 steps, based on current MAST training results.
2. Enable checkpointing. We enabled `reshard_after_forward` on `final_layer` to solve the issue described in #1167 (comment).

## Test

If we run the following 2 runs with `deterministic = True`, the training loss curves should be identical:

1. Without checkpoint save and load, total steps = 10.
2. Save a checkpoint at step 5, then load the checkpoint at step 5 and continue training.

Currently issue #1194 makes the training losses not strictly identical. To exclude the influence of #1194, we reset the seeds (by calling `set_deterministic()`) at the beginning of step 6; with that, the checkpoint save/load keeps the training loss identical.

![Training loss curves with and without checkpoint save/load](https://github.com/user-attachments/assets/22882b71-378c-44fa-bd48-8a8f238aa1b0)
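For reference, a minimal sketch of what such a seed reset could look like in plain PyTorch; the actual torchtitan `set_deterministic()` helper may differ, and the seed value here is only illustrative:

```python
import torch

def set_deterministic(seed: int = 0) -> None:
    # Hypothetical sketch, not the project's actual helper: reset RNG state so
    # that step 6 of the resumed run sees the same random numbers as step 6 of
    # the baseline run after the checkpoint load.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.use_deterministic_algorithms(True)
    torch.backends.cudnn.deterministic = True
```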

File tree

6 files changed, +44 −5 lines changed


torchtitan/experiments/flux/README.md

Lines changed: 11 additions & 1 deletion
@@ -23,13 +23,23 @@ Run the following command to train the model on a single GPU:

 ```

+If you want to train with another model config, run the following command:
+```bash
+CONFIG_FILE="./torchtitan/experiments/flux/train_configs/flux_schnell_model.toml" ./torchtitan/experiments/flux/run_train.sh
+```
+
 ## Supported Features
 - Parallelism: The model supports FSDP, HSDP for training on multiple GPUs.
 - Activation checkpointing: The model uses activation checkpointing to reduce memory usage during training.
+- Distributed checkpointing and loading.
+- Notes on the current checkpointing implementation: Currently we need to enable `reshard_after_forward=True` before the eval
+process and set it back to `False` after the eval process. The reason is that the eval step only runs forward, not backward,
+so the FSDP `reshard_after_forward` plan would otherwise interfere with how the parameters are sharded for a subsequent checkpoint save.
+


 ## TODO
 - [ ] More parallelism support (Tensor Parallelism, Context Parallelism, etc)
-- [ ] Support for distributed checkpointing and loading
 - [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
 - [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
+- [ ] Checkpointing followup: Merge resharding strategy in `flux/trainer.py` to `parallel_flux.py`

torchtitan/experiments/flux/dataset/flux_dataset.py

Lines changed: 1 addition & 0 deletions
@@ -58,6 +58,7 @@ def _process_cc12m_image(
     if resized_img.mode != "RGB":
         resized_img = resized_img.convert("RGB")

+    # Normalize the image to [-1, 1]
     np_img = np.array(resized_img).transpose((2, 0, 1))
     tensor_img = torch.tensor(np_img).float() / 255.0 * 2.0 - 1.0

torchtitan/experiments/flux/train.py

Lines changed: 4 additions & 0 deletions
@@ -182,7 +182,11 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
             or self.step == self.job_config.training.steps
         ):
             model.eval()
+            # We need to set reshard_after_forward before the last forward pass,
+            # so the model weights are sharded the same way for checkpoint saving.
+            model.final_layer.set_reshard_after_forward(True)
             self.eval_step()
+            model.final_layer.set_reshard_after_forward(False)
             model.train()

     def eval_step(self, prompt: str = "A photo of a cat"):
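For readability, the reshard toggle added above could also be expressed as a small helper. A hedged sketch follows; the context manager itself is an illustration and not code from this commit, while `set_reshard_after_forward` is the FSDP2 method already used in the diff:

```python
from contextlib import contextmanager

@contextmanager
def reshard_final_layer_for_eval(model):
    # Sketch only: make the eval-only forward pass reshard final_layer's
    # parameters afterwards, so the sharded state matches what a subsequent
    # checkpoint save expects, then restore the training-time setting.
    model.final_layer.set_reshard_after_forward(True)
    try:
        yield
    finally:
        model.final_layer.set_reshard_after_forward(False)

# Hypothetical usage inside train_step():
#     with reshard_final_layer_for_eval(model):
#         self.eval_step()
```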

torchtitan/experiments/flux/train_configs/debug_model.toml

Lines changed: 8 additions & 0 deletions
@@ -64,3 +64,11 @@ custom_args_module = "torchtitan.experiments.flux.flux_argparser"

 [activation_checkpoint]
 mode = "full"
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval = 5
+model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
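The new `[checkpoint]` table wires the Flux trainer into torchtitan's distributed checkpointing. As a rough illustration of the underlying mechanism (PyTorch's `torch.distributed.checkpoint`, not torchtitan's own checkpoint manager; the toy model, checkpoint path, and flat state-dict layout below are assumptions for the example):

```python
import torch
import torch.distributed.checkpoint as dcp
from torch import nn

# Illustrative stand-ins for the FSDP-sharded Flux model and its optimizer.
model = nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters())

# Save: with sharded (DTensor) parameters, each rank writes only its own shards.
state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
dcp.save(state, checkpoint_id="checkpoint/step-5")

# Resume: load in place into a matching state dict, then restore it.
dcp.load(state, checkpoint_id="checkpoint/step-5")
model.load_state_dict(state["model"])
optimizer.load_state_dict(state["optimizer"])
```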

torchtitan/experiments/flux/train_configs/flux_dev_model.toml

Lines changed: 10 additions & 2 deletions
@@ -28,14 +28,14 @@ lr = 1e-4
 eps = 1e-8

 [lr_scheduler]
-warmup_steps = 30_000 # lr scheduler warm up, normally 20% of the train steps
+warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.0 # no decay

 [training]
 batch_size = 4
 seq_len = 512
 max_norm = 1.0 # grad norm clipping
-steps = 300_000
+steps = 30_000
 compile = false
 dataset = "cc12m-wds"
 classifer_free_guidance_prob = 0.1
@@ -63,3 +63,11 @@ custom_args_module = "torchtitan.experiments.flux.flux_argparser"

 [activation_checkpoint]
 mode = "full"
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval = 1_000
+model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]

torchtitan/experiments/flux/train_configs/flux_schnell_model.toml

Lines changed: 10 additions & 2 deletions
@@ -28,14 +28,14 @@ lr = 1e-4
 eps = 1e-8

 [lr_scheduler]
-warmup_steps = 30_000 # lr scheduler warm up, normally 20% of the train steps
+warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.0 # no decay

 [training]
 batch_size = 4
 seq_len = 512
 max_norm = 1.0 # grad norm clipping
-steps = 300_000
+steps = 30_000
 compile = false
 dataset = "cc12m-wds"
 classifer_free_guidance_prob = 0.1
@@ -63,3 +63,11 @@ custom_args_module = "torchtitan.experiments.flux.flux_argparser"

 [activation_checkpoint]
 mode = "full"
+
+[checkpoint]
+enable_checkpoint = false
+folder = "checkpoint"
+interval = 1_000
+model_weights_only = false
+export_dtype = "float32"
+async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
