[FSDP2] fully_shard(reshard_after_forward=False) for root #1253


Status: Closed · wanted to merge 1 commit

Conversation

weifengpy (Contributor)

For the root model, we should always set reshard_after_forward=False regardless of pp_enabled, because the root model's parameters are used immediately at the start of backward; resharding them after forward just forces an extra all-gather.

This PR also future-proofs against the PyTorch-side change pytorch/pytorch#154704.

The same code has propagated to NeMo; I will follow up on it as well: https://github.com/NVIDIA/NeMo/blob/373288be4bea04ece4ceb189bc7448a0ca02aa97/nemo/lightning/pytorch/strategies/utils.py#L524C7-L524C66
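For illustration, the per-module policy described above can be sketched as a small helper. This is a hypothetical sketch (the function name and structure are not actual torchtitan code), capturing the idea that the root is always kept unsharded between forward and backward, while non-root blocks reshard unless pipeline parallelism is enabled:

```python
def reshard_after_forward_for(is_root: bool, pp_enabled: bool) -> bool:
    """Hypothetical helper deciding reshard_after_forward per module.

    Root parameters are the first ones needed when backward starts, so
    resharding the root after forward only buys an extra all-gather.
    """
    if is_root:
        # Keep the root unsharded between forward and backward,
        # regardless of whether pipeline parallelism is enabled.
        return False
    # Non-root blocks: reshard to save memory, except under pipeline
    # parallelism where (in this sketch) parameters stay gathered.
    return not pp_enabled
```

A caller would pass the result to fully_shard per module; the point is only that the root's value never depends on pp_enabled.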

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Jun 2, 2025
@weifengpy weifengpy requested review from tianyu-l and wwwjn June 2, 2025 04:22
tianyu-l (Contributor) left a comment


I think this is covered by #1094, which is blocked because the memory estimation tool doesn't support fully_shard([list of modules]).

Is there a chance you could help unblock?

Also, I wonder whether this is a memory-throughput trade-off rather than a pure throughput gain: IIRC (without full AC) the peak memory occurs at loss forward + backward, so wouldn't always setting reshard_after_forward=False on the root/last module increase peak memory?
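To put a rough number on that concern, here's an illustrative back-of-envelope estimate (hypothetical helper and sizes, not FSDP2's actual memory accounting; assumes bf16 parameters):

```python
def extra_peak_bytes(param_count: int, world_size: int,
                     bytes_per_param: int = 2) -> int:
    """Illustrative estimate: extra bytes one rank holds at the
    loss forward + backward peak when the root module is NOT
    resharded after forward.

    Sharded, each rank keeps param_count / world_size parameters;
    skipping the reshard keeps the full unsharded copy alive, so
    the cost is the difference between full and sharded size.
    """
    full = param_count * bytes_per_param   # unsharded copy on this rank
    sharded = full // world_size           # this rank's shard only
    return full - sharded

# Hypothetical example: 1B root-level params sharded over 8 ranks in bf16
# -> 1_750_000_000 extra bytes (~1.75 GB) held through the peak per rank.
print(extra_peak_bytes(1_000_000_000, 8))
```

Whether that extra footprint matters depends on where the model's actual peak lands, which is exactly the trade-off being discussed.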

weifengpy (Contributor, Author)

> I think this is covered by #1094, which is blocked because the memory estimation tool doesn't support fully_shard([list of modules]).
>
> Is there a chance you could help unblock?

Good to know this is covered. I'm tight on bandwidth; maybe we need to wait for Sanket, as you suggested in the PR.

> wouldn't always setting reshard_after_forward=False on the root/last module increase peak memory?

Good point! It's a trade-off indeed.

Currently, for the root model, no matter whether the user configures reshard_after_forward=True or False, FSDP2 internally overrides it to False. That's why I drafted this PR: to reflect what's truly going on inside FSDP2. Does that make sense? This also applies to #1094. I don't want users to think root modules are resharded when they configure reshard_after_forward=True.

tianyu-l (Contributor) commented Jun 3, 2025

> maybe we need to wait for Sanket as you suggested in the PR

sounds good

> FSDP2 internally override it to False. That's why I drafted this PR to reflect what's truly going on inside fsdp2. does it make sense?

If it's a trade-off, then the forced behavior only makes sense when memory is not a concern. I'd ask: why not give this control to users?

> This also applies to #1094.

After #1094, everything in the root module will be manually wrapped, so the root module's behavior no longer matters.

@weifengpy weifengpy closed this Jun 3, 2025
weifengpy (Contributor, Author)

> If it's a trade-off, then the forced behavior only makes sense when memory is not a concern. I'd ask why not giving this control to users?

> After #1094, everything in the root module will be manually wrapped, so the root module behavior doesn't matter any more.

Got it, thanks for explaining. Closing this PR to leave the option to the user.
