add load checkpoint support for virtual table #3037


Closed
wants to merge 1 commit

Conversation

bobbyliujb

Summary:
X-link: pytorch/FBGEMM#4250

After all of the rebasing and landing, trunk is still missing some of the changes needed for checkpoint loading:

* change `create_virtual_table_global_metadata` to respect `local_weight_count` on each rank, or fall back to the param size as the number of rows on each rank
* register a `load_state_dict` post-hook (via `register_load_state_dict_post_hook`) in `ShardedEmbeddingCollection` so that loading skips the weight tensor
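The first bullet can be sketched roughly as follows. This is a minimal illustration of the fallback logic described above, not torchrec's actual `create_virtual_table_global_metadata`; the function and argument names here are hypothetical.

```python
# Hypothetical sketch: choose the number of rows for each rank of a virtual
# table, preferring the rank's reported local_weight_count and falling back
# to the local param size when no count is available.
from typing import List, Optional


def rows_per_rank(
    local_weight_counts: List[Optional[int]],
    param_rows_per_rank: List[int],
) -> List[int]:
    # Prefer the count each rank reports for its virtual table; otherwise
    # use the size of the local parameter as the number of rows on that rank.
    return [
        count if count is not None else rows
        for count, rows in zip(local_weight_counts, param_rows_per_rank)
    ]


# e.g. rank 1 reported no count, so its param size (512) is used:
print(rows_per_rank([100, None, 300], [512, 512, 512]))  # [100, 512, 300]
```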
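For the second bullet, `register_load_state_dict_post_hook` is the standard `torch.nn.Module` API for this: the hook receives the `incompatible_keys` result, and removing a key from `missing_keys` makes a strict load succeed even though the checkpoint carries no entry for it. Below is a standalone sketch of the pattern on a toy module; `VirtualTableModule` is illustrative and not torchrec's `ShardedEmbeddingCollection`.

```python
# Hypothetical sketch: ignore a virtual-table weight tensor during strict
# load_state_dict by clearing it from missing_keys in a post-hook.
import torch
import torch.nn as nn


class VirtualTableModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # Stand-in for a virtual-table weight that is materialized elsewhere
        # and must not be read from the checkpoint.
        self.weight = nn.Parameter(torch.zeros(4, 8))
        self.register_load_state_dict_post_hook(self._ignore_weight)

    @staticmethod
    def _ignore_weight(module: nn.Module, incompatible_keys) -> None:
        # Dropping the key from missing_keys suppresses the strict-loading
        # error for the absent "weight" entry.
        incompatible_keys.missing_keys[:] = [
            k for k in incompatible_keys.missing_keys if not k.endswith("weight")
        ]


m = VirtualTableModule()
m.load_state_dict({}, strict=True)  # no error: the hook dropped the key
```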

Differential Revision:
D75843542

Privacy Context Container: L1138451

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Jun 3, 2025
@facebook-github-bot
Contributor

This pull request was exported from Phabricator. Differential Revision: D75843542

facebook-github-bot pushed a commit to pytorch/FBGEMM that referenced this pull request Jun 4, 2025
Summary:
X-link: pytorch/torchrec#3037

X-link: facebookresearch/FBGEMM#1329

Pull Request resolved: #4250

After all of the rebasing and landing, trunk is still missing some of the changes needed for checkpoint loading:
* change `create_virtual_table_global_metadata` to respect `local_weight_count` on each rank, or fall back to the param size as the number of rows on each rank
* register a `load_state_dict` post-hook (via `register_load_state_dict_post_hook`) in `ShardedEmbeddingCollection` so that loading skips the weight tensor

Reviewed By: emlin

Differential Revision:
D75843542

Privacy Context Container: L1138451

fbshipit-source-id: 8b3c8d76bb2e7ba2137c8899de2c03d534f1365c
Labels: CLA Signed, fb-exported