Add padding in dynamic sharding for tensors before all2all #2944
Conversation
This pull request was exported from Phabricator. Differential Revision: D74150894
5f08838 to e98680a
This pull request was exported from Phabricator. Differential Revision: D74150894
e98680a to c406188
This pull request was exported from Phabricator. Differential Revision: D74150894
Summary:
Since we can't expect shards in an embedding module to have the same dimensions along both dim 0 and dim 1, we have to pad the tensors passed into the `all_to_all_single` collective so that the expensive collective is only called once.
This diff:
1. Adds the logic for padding tensors in both dimensions.
2. Adds logic to remove the padding when updating the state dict after resharding.
3. Removes the original implementation of concatenating input tensors along dim 1 (which assumes dim 0 can be variable but dim 1 is consistent across all shards) and transposing.
   1. This ensures that the existing CW unit test exercises the padding logic, as the CW unit test was the one that previously failed due to inconsistent dimensions.
Padding leverages `nn.functional.pad` and pads tensors with value 0 on the right and bottom, e.g.
```
t = [[1, 2],
     [3, 4]]
max_dim_0 = 4
max_dim_1 = 3
t = pad_tensor_to_max_dims(t, max_dim_0, max_dim_1)
print(t)
>>> [[1, 2, 0],
     [3, 4, 0],
     [0, 0, 0],
     [0, 0, 0]]
```
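A minimal sketch of how the two helpers could look, assuming dim 0 is rows and dim 1 is columns; `pad_tensor_to_max_dims` mirrors the call in the example above, while `remove_padding` is a hypothetical name for the inverse step used when updating the state dict (the PR's actual implementation may differ):

```
import torch
import torch.nn.functional as F


def pad_tensor_to_max_dims(t: torch.Tensor, max_dim_0: int, max_dim_1: int) -> torch.Tensor:
    # For a 2D tensor, F.pad takes (left, right, top, bottom); padding only on
    # the right and bottom keeps the original values at their original indices.
    return F.pad(
        t,
        pad=(0, max_dim_1 - t.size(1), 0, max_dim_0 - t.size(0)),
        mode="constant",
        value=0,
    )


def remove_padding(t: torch.Tensor, dim_0: int, dim_1: int) -> torch.Tensor:
    # Inverse of the padding step: slice off the zero rows/columns before
    # writing the shard back into the state dict after resharding.
    return t[:dim_0, :dim_1]


t = torch.tensor([[1, 2], [3, 4]])
padded = pad_tensor_to_max_dims(t, max_dim_0=4, max_dim_1=3)
assert padded.shape == (4, 3)
assert torch.equal(remove_padding(padded, 2, 2), t)
```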
The max dimensions for dim 0 and dim 1 are determined by going through all shard sizes that are being redistributed. This is because we need to ensure the `output_tensor` passed into a2a has a large enough size.
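As a rough illustration of that sizing logic (not the PR's code): the shard sizes below, the flattening scheme, and the bare `dist.all_to_all_single` call are all assumptions; in practice the shards would be grouped per destination rank with explicit split sizes.

```
import torch
import torch.distributed as dist
import torch.nn.functional as F

# Hypothetical shard sizes (rows, cols) taking part in the redistribution,
# collected from the embedding module's shard metadata.
shard_sizes = [(4, 2), (3, 3), (2, 3)]

# The output_tensor handed to all_to_all_single must fit the largest shard,
# so take the max over each dimension separately.
max_dim_0 = max(rows for rows, _ in shard_sizes)
max_dim_1 = max(cols for _, cols in shard_sizes)

# Pad every shard to (max_dim_0, max_dim_1) and flatten, so the collective
# moves fixed-size blocks and only needs to be called once.
local_shards = [torch.randn(rows, cols) for rows, cols in shard_sizes]
padded = [
    F.pad(t, (0, max_dim_1 - t.size(1), 0, max_dim_0 - t.size(0)))
    for t in local_shards
]
input_tensor = torch.cat([t.flatten() for t in padded])
output_tensor = torch.empty_like(input_tensor)

# Only run the collective when a process group has actually been initialized.
if dist.is_initialized():
    dist.all_to_all_single(output_tensor, input_tensor)
```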
Differential Revision: D74150894
c406188 to 43b793a
This pull request was exported from Phabricator. Differential Revision: D74150894