Skip to content

[Flux] Enable checkpointing #1195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion torchtitan/experiments/flux/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,23 @@ Run the following command to train the model on a single GPU:

```

If you want to train with other model config, run the following command:
```bash
CONFIG_FILE="./torchtitan/experiments/flux/train_configs/flux_schnell_model.toml" ./torchtitan/experiments/flux/run_train.sh
```

## Supported Features
- Parallelism: The model supports FSDP, HSDP for training on multiple GPUs.
- Activation checkpointing: The model uses activation checkpointing to reduce memory usage during training.
- Distributed checkpointing and loading.
- Notes on the current checkpointing implementation: Currently we need to enable `reshard_after_forward=True` before eval
process, and set it back to `False` after eval process. The reason is that eval step only runs forward, but not backward,
so FSDP reshard_after_forward plan would interfere with how parameters look like for the potential subsequent checkpointing step.



## TODO
- [ ] More parallesim support (Tensor Parallelism, Context Parallelism, etc)
- [ ] Support for distributed checkpointing and loading
- [ ] Implement the num_flops_per_token calculation in get_nparams_and_flops() function
- [ ] Implement test cases in CI for FLUX model. Adding more unit tests for FLUX model (eg, unit test for preprocessor, etc)
- [ ] Checkpointing followup: Merge resharding strategy in `flux/trainer.py` to `parallel_flux.py`
3 changes: 2 additions & 1 deletion torchtitan/experiments/flux/dataset/flux_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,9 @@ def _process_cc12m_image(
if resized_img.mode != "RGB":
resized_img = resized_img.convert("RGB")

# Normalize the image to [-1, 1]
np_img = np.array(resized_img).transpose((2, 0, 1))
tensor_img = torch.tensor(np_img).float() / 255.0
tensor_img = torch.tensor(np_img).float() / 255.0 * 2.0 - 1.0

# NOTE: The following commented code is an alternative way
# img_transform = transforms.Compose(
Expand Down
4 changes: 4 additions & 0 deletions torchtitan/experiments/flux/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
or self.step == self.job_config.training.steps
):
model.eval()
# We need to set reshard_after_forward before last forward pass.
# So the model wieghts are sharded the same way for checkpoint saving.
model.final_layer.set_reshard_after_forward(True)
self.eval_step()
model.final_layer.set_reshard_after_forward(False)
model.train()

def eval_step(self, prompt: str = "A photo of a cat"):
Expand Down
8 changes: 8 additions & 0 deletions torchtitan/experiments/flux/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,11 @@ custom_args_module = "torchtitan.experiments.flux.flux_argparser"

[activation_checkpoint]
mode = "full"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 5
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
12 changes: 10 additions & 2 deletions torchtitan/experiments/flux/train_configs/flux_dev_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ lr = 1e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 30_000 # lr scheduler warm up, normally 20% of the train steps
warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.0 # no decay

[training]
batch_size = 4
seq_len = 512
max_norm = 1.0 # grad norm clipping
steps = 300_000
steps = 30_000
compile = false
dataset = "cc12m-wds"
classifer_free_guidance_prob = 0.1
Expand Down Expand Up @@ -63,3 +63,11 @@ custom_args_module = "torchtitan.experiments.flux.flux_argparser"

[activation_checkpoint]
mode = "full"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 1_000
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]
12 changes: 10 additions & 2 deletions torchtitan/experiments/flux/train_configs/flux_schnell_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@ lr = 1e-4
eps = 1e-8

[lr_scheduler]
warmup_steps = 30_000 # lr scheduler warm up, normally 20% of the train steps
warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.0 # no decay

[training]
batch_size = 4
seq_len = 512
max_norm = 1.0 # grad norm clipping
steps = 300_000
steps = 30_000
compile = false
dataset = "cc12m-wds"
classifer_free_guidance_prob = 0.1
Expand Down Expand Up @@ -63,3 +63,11 @@ custom_args_module = "torchtitan.experiments.flux.flux_argparser"

[activation_checkpoint]
mode = "full"

[checkpoint]
enable_checkpoint = false
folder = "checkpoint"
interval = 1_000
model_weights_only = false
export_dtype = "float32"
async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"]