Describe the bug
Loading a DCP checkpoint fails on the _extra_state entries when training with FSDP2.
Steps/Code to reproduce bug
Save a DCP checkpoint while training with FSDP2, then try to load it back. The load fails with:
[rank4]: ValueError: Size mismatch between saved torch.Size([2322]) and current: torch.Size([4]) for model.diffusion_transformer.layers.0.feed_forward.ffn._extra_state
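The error message appears to come from the size check in PyTorch's DCP load planner, which compares the size recorded for each key at save time with the size of the matching entry in the state dict being loaded into. One plausible reading, not confirmed here, is that _extra_state is stored as a serialized byte buffer whose length depends on its contents, so a freshly constructed model reports a different length than the one that was saved. The pure-Python sketch below models only that check. serialize_extra_state, check_load_sizes and the state contents are hypothetical stand-ins, not Transformer Engine or PyTorch APIs.

```python
import pickle


def serialize_extra_state(state: dict) -> bytes:
    # Hypothetical stand-in for a module's get_extra_state(): the state is
    # pickled, so the length of the resulting buffer depends on its contents.
    return pickle.dumps(state)


def check_load_sizes(saved: dict, current: dict) -> None:
    # Models a size check like the one in the error above: each entry in the
    # state dict being loaded into must match the size recorded at save time.
    for fqn, cur in current.items():
        if fqn not in saved:
            continue  # keys missing from the checkpoint are not modeled here
        saved_size, cur_size = len(saved[fqn]), len(cur)
        if saved_size != cur_size:
            raise ValueError(
                f"Size mismatch between saved {saved_size} and "
                f"current: {cur_size} for {fqn}"
            )


key = "layers.0.feed_forward.ffn._extra_state"

# Hypothetical contents: after FP8 training steps the extra state carries
# scaling metadata, so its serialized form is comparatively large.
saved = {key: serialize_extra_state({"amax_history": [0.0] * 256, "scale": [1.0] * 4})}

# Hypothetical contents: a freshly constructed model that has not yet run an
# FP8 forward pass may report a much smaller extra state.
current = {key: serialize_extra_state({})}

try:
    check_load_sizes(saved, current)
except ValueError as err:
    print(err)  # reports a size mismatch for the _extra_state key
```

If this reading is right, any content-dependent serialization of _extra_state will trip a loader that preallocates buffers from the current model's state dict, regardless of the sharding strategy.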
Expected behavior
The checkpoint, including the _extra_state entries, should load back correctly so that training can resume.
Environment overview
- Environment location: Bare metal
- Method of Transformer Engine install: pip install, v2.3
Environment details
- OS version: Ubuntu 22.04
- PyTorch version: 2.7
- Python version: 3.12
- Transformer Engine version: v2.3
- CUDA version: 12.8
- cuDNN version: 9 (approximate)
Device details
- GPU model: H100s