Skip to content

Commit 71ff1a3

Browse files
Raahul Kalyaan Jakka and facebook-github-bot
Raahul Kalyaan Jakka
authored and committed
Adding function to create a snapshot and exposing it from EmbeddingRocksDBWrapper (#3024)
Summary: X-link: pytorch/FBGEMM#4223 X-link: facebookresearch/FBGEMM#1299 Design doc: https://docs.google.com/document/d/149LdAEHOLP7ei4hwVVkAFXGa4N9uLs1J7efxfBZp3dY/edit?tab=t.0#heading=h.49t3yfaqmt54 Context: We are enabling the usage of the RocksDB checkpoint feature in KVTensorWrapper. This allows us to create checkpoints of the embedding tables on SSD. Later, these checkpoints are used by the checkpointing component to create a checkpoint and upload it to Manifold. In this diff: creating a function to create a checkpoint and exposing it through EmbeddingRocksDBWrapper. Reviewed By: duduyi2013 Differential Revision: D75489841
1 parent e7dc586 commit 71ff1a3

File tree

4 files changed

+46
-0
lines changed

4 files changed

+46
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,6 +1225,12 @@ def purge(self) -> None:
12251225
self.emb_module.lxu_cache_weights.zero_()
12261226
self.emb_module.lxu_cache_state.fill_(-1)
12271227

1228+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """Ask the wrapped emb_module to take a RocksDB hard-link checkpoint.

    Must be invoked before state_dict() when publishing, so the SSD-backed
    embedding state is captured in a consistent snapshot.
    """
    emb = self.emb_module
    emb.create_rocksdb_hard_link_snapshot()
1233+
12281234
# pyre-ignore [15]
12291235
def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
12301236
List[PartiallyMaterializedTensor],
@@ -1520,6 +1526,12 @@ def purge(self) -> None:
15201526
self.emb_module.lxu_cache_weights.zero_()
15211527
self.emb_module.lxu_cache_state.fill_(-1)
15221528

1529+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """Delegate RocksDB hard-link checkpoint creation to the wrapped emb_module.

    Callers must run this ahead of state_dict() for publish so that the
    embedding tables on SSD are snapshotted consistently.
    """
    module = self.emb_module
    module.create_rocksdb_hard_link_snapshot()
1534+
15231535
# pyre-ignore [15]
15241536
def split_embedding_weights(
15251537
self, no_snapshot: bool = True, should_flush: bool = True
@@ -2033,6 +2045,12 @@ def purge(self) -> None:
20332045
self.emb_module.lxu_cache_weights.zero_()
20342046
self.emb_module.lxu_cache_state.fill_(-1)
20352047

2048+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """Forward the RocksDB hard-link checkpoint request to `self.emb_module`.

    Required before calling state_dict() for publish; the underlying module
    creates the actual on-SSD snapshot.
    """
    target = self.emb_module
    target.create_rocksdb_hard_link_snapshot()
2053+
20362054
# pyre-ignore [15]
20372055
def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
20382056
List[PartiallyMaterializedTensor],

torchrec/distributed/embedding.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1533,6 +1533,15 @@ def _embedding_dim_for_sharding_type(self, sharding_type: str) -> int:
15331533
else self._embedding_dim
15341534
)
15351535

1536+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """Create a RocksDB hard-link snapshot on every lookup that supports it.

    Needed before state_dict() is called for publish, so the SSD-backed
    embedding state is captured in a consistent checkpoint.
    """
    for lookup in self._lookups:
        # Unwrap any DistributedDataParallel layers to reach the real module.
        while isinstance(lookup, DistributedDataParallel):
            lookup = lookup.module
        # Bug fix: the original passed the RESULT of calling the method to
        # callable() -- that invoked the snapshot inside the condition and,
        # since the method returns None, callable(None) was always False, so
        # the guarded call never ran (and a callable return value would have
        # triggered a double invocation). Check the attribute itself and call
        # it exactly once.
        snapshot_fn = getattr(lookup, "create_rocksdb_hard_link_snapshot", None)
        if callable(snapshot_fn):
            snapshot_fn()
1544+
15361545
@property
15371546
def fused_optimizer(self) -> KeyedOptimizer:
15381547
return self._optim

torchrec/distributed/embedding_lookup.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,11 @@ def flush(self) -> None:
399399
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
400400
emb_module.flush()
401401

402+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """Propagate a RocksDB hard-link snapshot request to each SSD-backed
    (KeyValueEmbedding) module in this lookup; other module types are skipped.
    """
    kv_modules = (m for m in self._emb_modules if isinstance(m, KeyValueEmbedding))
    for kv in kv_modules:
        kv.create_rocksdb_hard_link_snapshot()
406+
402407
def purge(self) -> None:
403408
for emb_module in self._emb_modules:
404409
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
@@ -723,6 +728,11 @@ def flush(self) -> None:
723728
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
724729
emb_module.flush()
725730

731+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """Request a RocksDB hard-link snapshot from every KeyValueEmbedding in
    `self._emb_modules`; modules of any other type are left untouched.
    """
    for candidate in self._emb_modules:
        if not isinstance(candidate, KeyValueEmbedding):
            continue
        candidate.create_rocksdb_hard_link_snapshot()
735+
726736
def purge(self) -> None:
727737
for emb_module in self._emb_modules:
728738
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.

torchrec/distributed/embeddingbag.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1638,6 +1638,15 @@ def update_shards(
16381638
update_module_sharding_plan(self, changed_sharding_params)
16391639
return
16401640

1641+
def create_rocksdb_hard_link_snapshot(self) -> None:
    """Create a RocksDB hard-link snapshot on every lookup that supports it.

    Needed before state_dict() is called for publish, so the SSD-backed
    embedding state is captured in a consistent checkpoint.
    """
    for lookup in self._lookups:
        # Unwrap any DistributedDataParallel layers to reach the real module.
        while isinstance(lookup, DistributedDataParallel):
            lookup = lookup.module
        # Bug fix: the original passed the RESULT of calling the method to
        # callable() -- that invoked the snapshot inside the condition and,
        # since the method returns None, callable(None) was always False, so
        # the guarded call never ran (and a callable return value would have
        # triggered a double invocation). Check the attribute itself and call
        # it exactly once.
        snapshot_fn = getattr(lookup, "create_rocksdb_hard_link_snapshot", None)
        if callable(snapshot_fn):
            snapshot_fn()
1649+
16411650
@property
16421651
def fused_optimizer(self) -> KeyedOptimizer:
16431652
return self._optim

0 commit comments

Comments
 (0)