[FSDP2] Cast model to uniform dtype before fully_shard to fix mixed-dtype AssertionError#3985

Open
roycho96 wants to merge 2 commits into huggingface:main from roycho96:fix/mixed-precision
Conversation

@roycho96
Contributor

What does this PR do?

When mixed_precision is enabled, this PR casts model parameters to a uniform dtype before fully_shard() to prevent an AssertionError in _init_mp_dtypes().

Problem

FSDP2's _init_mp_dtypes() requires a uniform orig_dtype across all trainable parameters in a parameter group. When the model contains mixed dtypes, the first forward call crashes:

AssertionError: FSDP expects uniform original parameter dtype but got {torch.bfloat16, torch.float32}

Accelerate's fsdp2_prepare_model() currently passes the mixed-dtype model directly to fully_shard() without normalizing dtypes.

Fix

Cast all parameters to the mixed-precision param_dtype before fully_shard(), after the model_has_params4bit check. Models containing Params4bit are skipped to avoid destroying quantized weights.
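A minimal sketch of the normalization step described above. The helper name `cast_to_uniform_dtype` and the type-name check for `Params4bit` are illustrative, not the actual Accelerate implementation; the real fix lives inside fsdp2_prepare_model():

```python
# Hypothetical sketch: normalize parameter dtypes before fully_shard()
# so FSDP2's _init_mp_dtypes() sees a single orig_dtype per param group.
import torch
import torch.nn as nn


def cast_to_uniform_dtype(model: nn.Module, param_dtype: torch.dtype) -> nn.Module:
    for param in model.parameters():
        # Skip quantized weights (e.g. bitsandbytes Params4bit) so we do
        # not destroy their packed representation.
        if type(param).__name__ == "Params4bit":
            continue
        if param.is_floating_point() and param.dtype != param_dtype:
            param.data = param.data.to(param_dtype)
    return model


# A model with mixed dtypes, reproducing the shape of the reported error:
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
model[0].to(torch.bfloat16)  # first layer bf16, second stays fp32
assert {p.dtype for p in model.parameters()} == {torch.bfloat16, torch.float32}

cast_to_uniform_dtype(model, torch.bfloat16)
```

After the cast, every floating-point parameter shares the mixed-precision param_dtype, so fully_shard() no longer sees the `{torch.bfloat16, torch.float32}` set that triggers the assertion.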

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@SunMarc

@roycho96 roycho96 changed the title Fix/mixed precision [FSDP2] Cast model to uniform dtype before fully_shard to fix mixed-dtype AssertionError Mar 20, 2026