Skip to content

Commit d82f927

Browse files
Raahul Kalyaan Jakkafacebook-github-bot
Raahul Kalyaan Jakka
authored andcommitted
Adding function to create a snapshot and exposing it from EmbeddingRocksDBWrapper (pytorch#4223)
Summary: X-link: pytorch/torchrec#3024 Pull Request resolved: pytorch#4223 X-link: facebookresearch/FBGEMM#1299 Design doc: https://docs.google.com/document/d/149LdAEHOLP7ei4hwVVkAFXGa4N9uLs1J7efxfBZp3dY/edit?tab=t.0#heading=h.49t3yfaqmt54 Context: We are enabling the usage of rocksDB checkpoint feature in KVTensorWrapper. This allows us to create checkpoints of the embedding tables in SSD. Later, these checkpoints are used by the checkpointing component to create a checkpoint and upload it it to the manifold In this diff: Creating a function to create a checkpoint and exposing it to EmbeddingRocksDBWrapper Reviewed By: duduyi2013 Differential Revision: D75489841
1 parent 50d8eaa commit d82f927

File tree

3 files changed

+14
-1
lines changed

3 files changed

+14
-1
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,6 +2490,12 @@ def flush(self, force: bool = False) -> None:
24902490
self.ssd_db.flush()
24912491
self.last_flush_step = self.step
24922492

2493+
def create_rocksdb_hard_link_snapshot(self) -> None:
2494+
"""
2495+
Create a rocksdb hard link snapshot to provide cross procs access to the underlying data
2496+
"""
2497+
self.ssd_db.create_rocksdb_hard_link_snapshot(self.step)
2498+
24932499
def prepare_inputs(
24942500
self,
24952501
indices: Tensor,

fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
176176
return impl_->get_snapshot_count();
177177
}
178178

179+
void create_rocksdb_hard_link_snapshot(int64_t global_step) {
180+
impl_->create_checkpoint(global_step);
181+
}
182+
179183
private:
180184
friend class KVTensorWrapper;
181185

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,10 @@ static auto embedding_rocks_db_wrapper =
611611
.def("get_snapshot_count", &EmbeddingRocksDBWrapper::get_snapshot_count)
612612
.def(
613613
"get_keys_in_range_by_snapshot",
614-
&EmbeddingRocksDBWrapper::get_keys_in_range_by_snapshot);
614+
&EmbeddingRocksDBWrapper::get_keys_in_range_by_snapshot)
615+
.def(
616+
"create_rocksdb_hard_link_snapshot",
617+
&EmbeddingRocksDBWrapper::create_rocksdb_hard_link_snapshot);
615618

616619
static auto dram_kv_embedding_cache_wrapper =
617620
torch::class_<DramKVEmbeddingCacheWrapper>(

0 commit comments

Comments
 (0)