Add padding in dynamic sharding for tensors before all2all #2944
Open
aporialiao wants to merge 1 commit into pytorch:main from aporialiao:export-D74150894
Conversation
This pull request was exported from Phabricator. Differential Revision: D74150894
aporialiao force-pushed the export-D74150894 branch from 5f08838 to e98680a.

aporialiao added a commit to aporialiao/torchrec that referenced this pull request on May 5, 2025:

Summary: Given that we can't expect shards in an embedding module to have the same dimensions for both dim 0 and dim 1, we have to pad the tensors passed into the `all_to_all_single` collective to ensure we only call the expensive collective once.

This diff:
1. Adds the logic for padding tensors in both dimensions.
2. Adds logic to remove the padding when updating the state dict after resharding.
3. Removes the original implementation of concatenating input tensors along dim 1 (which assumes dim 0 can be variable but dim 1 is consistent across all shards) and transposing.
   1. This ensures that the existing CW unit test exercises the padding logic, as the CW unit test was the one that previously failed due to inconsistent dimensions.

Padding leverages `nn.functional.pad` and pads tensors with value 0 on the right and bottom, e.g.:

```
t = [[1, 2]
     [3, 4]]

max_dim_0 = 4
max_dim_1 = 3

t = pad_tensor_to_max_dims(t, max_dim_0, max_dim_1)
print(t)
>>> [[1, 2, 0, 0]
     [3, 4, 0, 0]
     [0, 0, 0, 0]]
```

Max dimensions for dim 0 and dim 1 are determined by going through all shard sizes of the embedding module, because we need to ensure the `output_tensor` passed into the a2a call is large enough.

> NOTE: This will be optimized later to go through only the shard sizes that are being redistributed.

Differential Revision: D74150894
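For reference, a minimal sketch of what a padding helper like the `pad_tensor_to_max_dims` used in the example above could look like, assuming `max_dim_0`/`max_dim_1` are the target sizes along dim 0 (rows) and dim 1 (columns) and that padding is done with `torch.nn.functional.pad`; the actual helper in this diff may differ in name and signature:

```
# Hedged sketch (not the PR's actual code): pad a 2-D shard with zeros on the
# right (dim 1) and bottom (dim 0) so every shard sent through
# all_to_all_single has the same padded shape.
import torch
import torch.nn.functional as F


def pad_tensor_to_max_dims(t: torch.Tensor, max_dim_0: int, max_dim_1: int) -> torch.Tensor:
    pad_rows = max_dim_0 - t.size(0)  # rows to add at the bottom
    pad_cols = max_dim_1 - t.size(1)  # columns to add on the right
    # For a 2-D tensor, F.pad's pad tuple is ordered (left, right, top, bottom).
    return F.pad(t, (0, pad_cols, 0, pad_rows), mode="constant", value=0.0)


shard = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
padded = pad_tensor_to_max_dims(shard, max_dim_0=3, max_dim_1=4)
print(padded)
# tensor([[1., 2., 0., 0.],
#         [3., 4., 0., 0.],
#         [0., 0., 0., 0.]])
```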
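Similarly, a hedged sketch of the "go through all shard sizes" step described in the summary; the plain list of (rows, cols) sizes is an illustrative stand-in for whatever shard metadata the diff actually walks:

```
# Hedged sketch: derive the per-dimension maxima across all shard sizes so the
# all_to_all output buffer can be sized once for the largest shard.
from typing import List, Tuple


def get_max_dims(shard_sizes: List[Tuple[int, int]]) -> Tuple[int, int]:
    max_dim_0 = max(rows for rows, _ in shard_sizes)
    max_dim_1 = max(cols for _, cols in shard_sizes)
    return max_dim_0, max_dim_1


# e.g. CW shards with uneven column widths plus a smaller table
print(get_max_dims([(100, 64), (100, 32), (40, 16)]))  # (100, 64)
```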
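And a sketch of the inverse step, removing the padding before the resharded tensor is written back into the state dict; `remove_padding` and its arguments are hypothetical names, not the diff's actual API:

```
# Hedged sketch: slice a padded shard back down to its original size, keeping
# only the top-left (orig_dim_0 x orig_dim_1) block that holds real data.
import torch


def remove_padding(padded: torch.Tensor, orig_dim_0: int, orig_dim_1: int) -> torch.Tensor:
    # narrow(dim, start, length) returns a view of the leading slice along `dim`.
    return padded.narrow(0, 0, orig_dim_0).narrow(1, 0, orig_dim_1)


padded = torch.zeros(3, 4)
padded[:2, :2] = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(remove_padding(padded, 2, 2))
# tensor([[1., 2.],
#         [3., 4.]])
```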
aporialiao force-pushed the export-D74150894 branch from e98680a to c406188.
Labels
CLA Signed
This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
fb-exported