DCP Checkpoint Load Fails for _extra_state when training in FSDP2 #1860

Open
@ajWithNucleus


Describe the bug

Loading a DCP (PyTorch Distributed Checkpoint) checkpoint fails on the _extra_state entries when the model is trained with FSDP2.

Steps/Code to reproduce bug

Save a DCP checkpoint while training with FSDP2, then try to load it back to resume. The load fails with:

[rank4]: ValueError: Size mismatch between saved torch.Size([2322]) and current: torch.Size([4]) for model.diffusion_transformer.layers.0.feed_forward.ffn._extra_state
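
One possible reading of the error above, offered as an assumption rather than a confirmed diagnosis: if _extra_state is stored as a flat byte buffer holding serialized (pickled) module state, its length depends on the contents, so a trained module and a freshly constructed one can disagree in size. The sketch below uses only the Python standard library. The helper names `serialize_extra_state` and `check_sizes` and the state contents are hypothetical and are not taken from Transformer Engine or PyTorch.

```python
import pickle


def serialize_extra_state(state: dict) -> bytes:
    """Hypothetical stand-in for a module's extra state: pickle a dict to bytes."""
    return pickle.dumps(state)


def check_sizes(saved: bytes, current: bytes, key: str) -> None:
    """Hypothetical strict size check, similar in spirit to the one in the error."""
    if len(saved) != len(current):
        raise ValueError(
            f"Size mismatch between saved {len(saved)} and current: "
            f"{len(current)} for {key}"
        )


# Made-up contents: state accumulated during training vs. a freshly built module.
trained = serialize_extra_state({"scale": [1.0] * 64, "amax_history": [0.5] * 128})
fresh = serialize_extra_state({})

try:
    check_sizes(trained, fresh, "layers.0.feed_forward.ffn._extra_state")
    mismatch = None
except ValueError as err:
    mismatch = str(err)

print(mismatch)
```

If this reading holds, any loader that requires the saved and current buffers to have the same shape would reject such an entry even though both sides are valid serialized state.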

Expected behavior

We should be able to load back the model properly to resume training.

Environment overview (please complete the following information)

  • Environment location: Baremetal
  • Method of Transformer Engine install: pip install, v2.3

Environment details

  • OS version: Ubuntu 22.04
  • PyTorch version: 2.7
  • Python version: 3.12
  • Transformer Engine version: v2.3
  • CUDA version: 12.8
  • CUDNN version: ~9

Device details

  • GPU model: H100s

Labels: bug