Commit ee0264c

Jianbo Liu authored and facebook-github-bot committed
add load checkpoint support for virtual table (#4250)
Summary:
X-link: pytorch/torchrec#3037
X-link: facebookresearch/FBGEMM#1329
Pull Request resolved: #4250

After all of the rebasing and landing, trunk was still missing some of the changes needed for checkpoint loading:

* change `create_virtual_table_global_metadata` to respect `local_weight_count` on each rank, or just use the param size as the number of rows on each rank
* register a `load_state_dict` post-hook (via `register_load_state_dict_post_hook`) in ShardedEmbeddingCollection so that it ignores loading the weight tensor

Reviewed By: emlin

Differential Revision: D75843542

Privacy Context Container: L1138451

fbshipit-source-id: 8b3c8d76bb2e7ba2137c8899de2c03d534f1365c
1 parent 2d3e7e5 commit ee0264c
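
The torchrec half of this fix (registering a load_state_dict post-hook on ShardedEmbeddingCollection) lands via the cross-linked pytorch/torchrec#3037 and is not shown in the diff below. As background, here is a minimal sketch of the mechanism using stock PyTorch; the module and hook names are illustrative, not the TorchRec implementation:

import torch

class ToyShardedCollection(torch.nn.Module):
    """Toy stand-in for torchrec's ShardedEmbeddingCollection."""

    def __init__(self) -> None:
        super().__init__()
        self.dense = torch.nn.Linear(4, 4)

def _ignore_weight_keys(module, incompatible_keys) -> None:
    # Drop the weight keys that virtual-table checkpoint loading intentionally
    # skips, so a strict load_state_dict() no longer raises on them.
    incompatible_keys.missing_keys[:] = [
        k for k in incompatible_keys.missing_keys if not k.endswith(".weight")
    ]

m = ToyShardedCollection()
m.register_load_state_dict_post_hook(_ignore_weight_keys)
# The checkpoint below is missing "dense.weight"; the hook clears that entry
# from missing_keys before the strict check, so the load succeeds.
m.load_state_dict({"dense.bias": torch.zeros(4)}, strict=True)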

File tree

1 file changed: +7 -5 lines

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 7 additions & 5 deletions
@@ -1959,7 +1959,7 @@ def split_optimizer_states(
         # init for checkpointing loading
         assert (
             self._cached_kvzch_data is not None
-            and self._cached_kvzch_data.cached_optimizer_state_per_table is not None
+            and self._cached_kvzch_data.cached_optimizer_state_per_table
         ), "optimizer state is not initialized for load checkpointing"
         return self._cached_kvzch_data.cached_optimizer_state_per_table
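
One behavioral note on this hunk: replacing `is not None` with a plain truthiness check means an empty per-table dict now also trips the assertion, which appears to be the intent, since an empty `cached_optimizer_state_per_table` means checkpoint loading was never initialized:

cached = {}  # an uninitialized per-table optimizer-state dict
print(cached is not None)  # True  -> the old check would pass here
print(bool(cached))        # False -> the truthiness check catches this case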

@@ -2365,13 +2365,17 @@ def streaming_write_weight_and_id_per_table(
         D_rounded = pad4(weight_state.size(1))  # padded to 4 bytes alignment
         dtype = self.weights_precision.as_dtype()
         kvt = torch.classes.fbgemm.KVTensorWrapper(
-            db=self.ssd_db,
             shape=[weight_state.size(0), self.cache_row_dim],
             dtype=dtype,
             row_offset=row_offset,
             snapshot_handle=None,
             sorted_indices=id_tensor,
         )
+        (
+            kvt.set_embedding_rocks_dp_wrapper(self.ssd_db)
+            if self.backend_type == BackendType.SSD
+            else kvt.set_dram_db_wrapper(self.ssd_db)
+        )
         # TODO: make chunk_size configurable or dynamic
         chunk_size = 10000
         row = weight_state.size(0)
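
The point of this hunk is that the backing store is no longer passed through the `db=` constructor argument; it is attached after construction, picking the RocksDB wrapper for the SSD backend and the DRAM wrapper otherwise. A minimal sketch of that dispatch with stand-ins for the fbgemm types (the `BackendType` values and the fake wrapper class below are assumptions for illustration, not the real bindings):

from enum import Enum

class BackendType(Enum):
    # Stand-in for fbgemm_gpu's BackendType enum referenced in the patch.
    SSD = 0
    DRAM = 1

def attach_db(kvt, db, backend_type: BackendType) -> None:
    # Mirrors the patched logic as a plain if/else; the patch itself uses a
    # parenthesized conditional expression evaluated for its side effect.
    if backend_type == BackendType.SSD:
        kvt.set_embedding_rocks_dp_wrapper(db)  # RocksDB-backed SSD store
    else:
        kvt.set_dram_db_wrapper(db)  # DRAM-backed store

class _FakeKVT:
    # Duck-typed stand-in for torch.classes.fbgemm.KVTensorWrapper.
    def set_embedding_rocks_dp_wrapper(self, db) -> None:
        print("attached RocksDB wrapper to", db)

    def set_dram_db_wrapper(self, db) -> None:
        print("attached DRAM wrapper to", db)

attach_db(_FakeKVT(), "ssd_db_handle", BackendType.SSD)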
@@ -2417,9 +2421,7 @@ def enable_load_state_dict_mode(self) -> None:
         logging.info(
             f"for checkpoint loading, table {i}, weight_state shape is {weight_state.shape}, opt_state shape is {opt_state.shape}"
         )
-        id_tensor = torch.zeros(
-            (self.local_weight_counts[i], 1), dtype=torch.int64, device="cpu"
-        )
+        id_tensor = torch.zeros((rows, 1), dtype=torch.int64, device="cpu")
         # pyre-ignore [16]
         self._cached_kvzch_data.cached_id_tensor_per_table.append(id_tensor)
         # pyre-ignore [16]
