Commit b2a33d1

Raahul Kalyaan Jakka authored and facebook-github-bot committed
Adding function to create a snapshot and exposing it from EmbeddingRocksDBWrapper (#3024)
Summary: Pull Request resolved: #3024
X-link: pytorch/FBGEMM#4223
X-link: facebookresearch/FBGEMM#1299
Design doc: https://docs.google.com/document/d/149LdAEHOLP7ei4hwVVkAFXGa4N9uLs1J7efxfBZp3dY/edit?tab=t.0#heading=h.49t3yfaqmt54

Context: We are enabling use of the RocksDB checkpoint feature in KVTensorWrapper. This lets us create checkpoints of the embedding tables stored on SSD. Later, the checkpointing component uses these checkpoints to build a checkpoint and upload it to Manifold.

In this diff: add a function that creates a checkpoint and expose it from EmbeddingRocksDBWrapper.

Reviewed By: duduyi2013

Differential Revision: D75489841
1 parent 3ec6f53 commit b2a33d1
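
To show how the new API is meant to be used, here is a minimal, hypothetical usage sketch. Only create_rocksdb_hard_link_snapshot() and its contract (take the snapshot before calling state_dict() for publish, per the docstrings below) come from this commit; the save_for_publish helper and the sharded_module argument are illustrative placeholders, not part of TorchRec.

import torch


def save_for_publish(sharded_module: torch.nn.Module, path: str) -> None:
    # Ask the SSD-backed embedding module to take a RocksDB hard-link
    # snapshot so that state_dict() reads from a consistent checkpoint.
    if hasattr(sharded_module, "create_rocksdb_hard_link_snapshot"):
        sharded_module.create_rocksdb_hard_link_snapshot()

    # With the snapshot in place, capture the state dict and hand it to
    # the checkpointing/upload component (e.g. for publishing to Manifold).
    torch.save(sharded_module.state_dict(), path)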

File tree

4 files changed, +47 -0 lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 19 additions & 0 deletions
@@ -1225,6 +1225,13 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)
 
+    # Todo: [Raahul46]: Add an intermediate parent class between embedding and kv to support these functions
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        """
+        Create a RocksDB checkpoint. This is needed before we call state_dict() for publish.
+        """
+        self.emb_module.create_rocksdb_hard_link_snapshot()
+
     # pyre-ignore [15]
     def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         List[PartiallyMaterializedTensor],
@@ -1525,6 +1532,12 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)
 
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        """
+        Create a RocksDB checkpoint. This is needed before we call state_dict() for publish.
+        """
+        self.emb_module.create_rocksdb_hard_link_snapshot()
+
     # pyre-ignore [15]
     def split_embedding_weights(
         self, no_snapshot: bool = True, should_flush: bool = False
@@ -2038,6 +2051,12 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)
 
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        """
+        Create a RocksDB checkpoint. This is needed before we call state_dict() for publish.
+        """
+        self.emb_module.create_rocksdb_hard_link_snapshot()
+
     # pyre-ignore [15]
     def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
         List[PartiallyMaterializedTensor],

torchrec/distributed/embedding.py

Lines changed: 9 additions & 0 deletions
@@ -1545,6 +1545,15 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
             else self._embedding_dim
         )
 
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        for lookup in self._lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            if hasattr(lookup, "create_rocksdb_hard_link_snapshot") and callable(
+                lookup.create_rocksdb_hard_link_snapshot
+            ):
+                lookup.create_rocksdb_hard_link_snapshot()
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim

torchrec/distributed/embedding_lookup.py

Lines changed: 10 additions & 0 deletions
@@ -399,6 +399,11 @@ def flush(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.flush()
 
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        for emb_module in self._emb_modules:
+            if isinstance(emb_module, KeyValueEmbedding):
+                emb_module.create_rocksdb_hard_link_snapshot()
+
     def purge(self) -> None:
         for emb_module in self._emb_modules:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
@@ -723,6 +728,11 @@ def flush(self) -> None:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
             emb_module.flush()
 
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        for emb_module in self._emb_modules:
+            if isinstance(emb_module, KeyValueEmbedding):
+                emb_module.create_rocksdb_hard_link_snapshot()
+
     def purge(self) -> None:
         for emb_module in self._emb_modules:
             # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.

torchrec/distributed/embeddingbag.py

Lines changed: 9 additions & 0 deletions
@@ -1638,6 +1638,15 @@ def update_shards(
         update_module_sharding_plan(self, changed_sharding_params)
         return
 
+    def create_rocksdb_hard_link_snapshot(self) -> None:
+        for lookup in self._lookups:
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            if hasattr(lookup, "create_rocksdb_hard_link_snapshot") and callable(
+                lookup.create_rocksdb_hard_link_snapshot
+            ):
+                lookup.create_rocksdb_hard_link_snapshot()
+
     @property
     def fused_optimizer(self) -> KeyedOptimizer:
         return self._optim
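
Taken together, the four files thread the new call from the sharded collection down to the SSD-backed table. The self-contained toy sketch below mirrors that dispatch chain with stand-in classes; the real TorchRec classes (the sharded embedding/embedding-bag collections, the grouped lookups, KeyValueEmbedding, and the RocksDB-backed emb_module) are replaced by simplified placeholders, so this is illustrative only.

from typing import List


class FakeSSDTable:
    # Stands in for the SSD-backed emb_module (EmbeddingRocksDBWrapper).
    def create_rocksdb_hard_link_snapshot(self) -> None:
        print("hard-link snapshot created")


class FakeKeyValueEmbedding:
    # Stands in for KeyValueEmbedding in batched_embedding_kernel.py.
    def __init__(self) -> None:
        self.emb_module = FakeSSDTable()

    def create_rocksdb_hard_link_snapshot(self) -> None:
        # Delegate to the underlying SSD table, as the kernel-level change does.
        self.emb_module.create_rocksdb_hard_link_snapshot()


class FakeGroupedLookup:
    # Stands in for the grouped lookup classes in embedding_lookup.py.
    def __init__(self, emb_modules: List[FakeKeyValueEmbedding]) -> None:
        self._emb_modules = emb_modules

    def create_rocksdb_hard_link_snapshot(self) -> None:
        for emb_module in self._emb_modules:
            if isinstance(emb_module, FakeKeyValueEmbedding):
                emb_module.create_rocksdb_hard_link_snapshot()


class FakeShardedCollection:
    # Stands in for the sharded collections in embedding.py / embeddingbag.py.
    def __init__(self, lookups: List[FakeGroupedLookup]) -> None:
        self._lookups = lookups

    def create_rocksdb_hard_link_snapshot(self) -> None:
        for lookup in self._lookups:
            fn = getattr(lookup, "create_rocksdb_hard_link_snapshot", None)
            if callable(fn):
                fn()


if __name__ == "__main__":
    collection = FakeShardedCollection([FakeGroupedLookup([FakeKeyValueEmbedding()])])
    collection.create_rocksdb_hard_link_snapshot()  # prints once per SSD table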
