Skip to content

Commit 50d8eaa

Browse files
raahul46facebook-github-bot
raahul46
authored andcommitted
Temporary Commit at 5/26/2025, 10:03:16 PM
Differential Revision: D75489827
1 parent faa8eff commit 50d8eaa

File tree

2 files changed

+133
-1
lines changed

2 files changed

+133
-1
lines changed

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
#include "embedding_rocksdb_wrapper.h"
1919
#include "fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h"
2020
#include "fbgemm_gpu/utils/ops_utils.h"
21-
21+
#include "rocksdb/utilities/checkpoint.h"
2222
using namespace at;
2323
using namespace ssd;
2424
using namespace kv_mem;
@@ -293,6 +293,37 @@ snapshot_ptr_t SnapshotHandle::get_snapshot_for_shard(size_t shard) const {
293293
return shard_snapshots_[shard];
294294
}
295295

296+
CheckpointHandle::CheckpointHandle(
297+
EmbeddingRocksDB* db,
298+
const std::string& tbe_uuid,
299+
const std::string& ckpt_uuid,
300+
const std::string& base_path,
301+
bool use_default_ssd_path)
302+
: db_(db), ckpt_uuid_(ckpt_uuid) {
303+
auto num_shards = db->num_shards();
304+
CHECK_GT(num_shards, 0);
305+
shard_checkpoints_.reserve(num_shards);
306+
for (auto shard = 0; shard < num_shards; ++shard) {
307+
auto rocksdb_path = kv_db_utils::get_rocksdb_path(
308+
base_path, shard, tbe_uuid, use_default_ssd_path);
309+
auto checkpoint_shard_dir =
310+
kv_db_utils::get_rocksdb_checkpoint_dir(shard, rocksdb_path);
311+
kv_db_utils::create_dir(checkpoint_shard_dir);
312+
rocksdb::Checkpoint* checkpoint = nullptr;
313+
rocksdb::Status s =
314+
rocksdb::Checkpoint::Create(db->dbs_[shard].get(), &checkpoint);
315+
CHECK(s.ok()) << "ERROR: Checkpoint init for tbe_uuid " << tbe_uuid
316+
<< ", db shard " << shard << " failed, " << s.code() << ", "
317+
<< s.ToString();
318+
std::string checkpoint_shard_path = checkpoint_shard_dir + "/" + ckpt_uuid_;
319+
s = checkpoint->CreateCheckpoint(checkpoint_shard_path);
320+
CHECK(s.ok()) << "ERROR: Checkpoint creation for tbe_uuid " << tbe_uuid
321+
<< ", db shard " << shard << " failed, " << s.code() << ", "
322+
<< s.ToString();
323+
shard_checkpoints_.push_back(checkpoint_shard_path);
324+
}
325+
}
326+
296327
EmbeddingSnapshotHandleWrapper::EmbeddingSnapshotHandleWrapper(
297328
const SnapshotHandle* handle,
298329
std::shared_ptr<EmbeddingRocksDB> db)

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,25 @@ class SnapshotHandle {
6565
std::vector<snapshot_ptr_t> shard_snapshots_;
6666
}; // class SnapshotHandle
6767

68+
using checkpoint_path = std::string;
69+
// @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
70+
class CheckpointHandle {
71+
public:
72+
explicit CheckpointHandle(
73+
EmbeddingRocksDB* db,
74+
const std::string& tbe_uuid,
75+
const std::string& ckpt_uuid,
76+
const std::string& base_path,
77+
bool use_default_ssd_path);
78+
79+
private:
80+
friend class EmbeddingRocksDB;
81+
82+
EmbeddingRocksDB* db_;
83+
std::string ckpt_uuid_;
84+
std::vector<checkpoint_path> shard_checkpoints_;
85+
}; // class CheckpointHandle
86+
6887
/// @ingroup embedding-ssd
6988
///
7089
/// @brief An implementation of EmbeddingKVDB for RocksDB
@@ -488,6 +507,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
488507
return snapshots_.find(snapshot_handle) != snapshots_.end();
489508
}
490509

510+
bool is_valid_checkpoint(const std::string ckpt_uuid) const {
511+
return checkpoints_.find(ckpt_uuid) != checkpoints_.end();
512+
}
513+
491514
int64_t get_snapshot_count() const {
492515
return snapshots_.size();
493516
}
@@ -505,6 +528,57 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
505528
return handlePtr;
506529
}
507530

531+
std::string create_checkpoint(int64_t global_step) {
532+
const auto num_ckpts = checkpoints_.size();
533+
if (num_ckpts > 0) {
534+
std::cerr << "WARNING: rocksdb create_checkpoint found " << num_ckpts
535+
<< " other checkpoints_" << std::endl;
536+
}
537+
538+
// If the global step already has a checkpoint handler registered, at the
539+
// time create_checkpoint is call, we assume the prev ckpt hanlder has
540+
// fullfilled its job already, thus it is ok to replace it with the new rdb
541+
// checkpoint for next use cases within the same global step
542+
if (global_step_to_ckpt_uuid_.find(global_step) !=
543+
global_step_to_ckpt_uuid_.end()) {
544+
LOG(WARNING)
545+
<< "multiple rdb checkpoint in one global step are being created, "
546+
"removing the prev rdb ckpt, please make sure it has fullfilled "
547+
"its use case, e.g. checkpoint and publish";
548+
}
549+
auto ckpt_uuid = facebook::strings::generateUUID();
550+
auto handle = std::make_unique<CheckpointHandle>(
551+
this, tbe_uuid_, ckpt_uuid, path_, use_default_ssd_path_);
552+
checkpoints_[ckpt_uuid] = std::move(handle);
553+
global_step_to_ckpt_uuid_[global_step] = ckpt_uuid;
554+
return ckpt_uuid;
555+
}
556+
557+
std::optional<std::string> get_active_checkpoint_uuid(int64_t global_step) {
558+
if (global_step_to_ckpt_uuid_.find(global_step) !=
559+
global_step_to_ckpt_uuid_.end()) {
560+
return std::make_optional<std::string>(
561+
global_step_to_ckpt_uuid_[global_step]);
562+
}
563+
return std::nullopt;
564+
}
565+
566+
void release_checkpoint(const std::string ckpt_uuid) {
567+
CHECK_EQ(is_valid_checkpoint(ckpt_uuid), true);
568+
LOG(INFO) << "Checkpoint " << ckpt_uuid << " released";
569+
checkpoints_.erase(ckpt_uuid);
570+
// sweep through global_step_to_ckpt_uuid_, it should be small
571+
int64_t glb_step_to_purge = -1;
572+
for (const auto& [global_step, uuid] : global_step_to_ckpt_uuid_) {
573+
if (ckpt_uuid == uuid) {
574+
glb_step_to_purge = global_step;
575+
break;
576+
}
577+
}
578+
CHECK_NE(glb_step_to_purge, -1) << "There must be a rdb ckpt uuid to purge";
579+
global_step_to_ckpt_uuid_.erase(glb_step_to_purge);
580+
}
581+
508582
void release_snapshot(const SnapshotHandle* snapshot_handle) {
509583
CHECK(is_valid_snapshot(snapshot_handle));
510584
LOG(INFO) << "Snapshot " << snapshot_handle << " released";
@@ -1117,6 +1191,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
11171191
}
11181192

11191193
friend class SnapshotHandle;
1194+
friend class CheckpointHandle;
11201195

11211196
std::vector<std::unique_ptr<rocksdb::DB>> dbs_;
11221197
std::vector<std::unique_ptr<Initializer>> initializers_;
@@ -1149,6 +1224,32 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
11491224
std::string tbe_uuid_;
11501225
std::string path_;
11511226
bool use_default_ssd_path_;
1227+
1228+
// rocksdb checkpoint is used to create an on disk database to support
1229+
// cross process read-only access
1230+
std::unordered_map<std::string, std::unique_ptr<CheckpointHandle>>
1231+
checkpoints_;
1232+
// this is used for KVTensor rdb checkpoint linking by global
1233+
// step, reasons are shown below
1234+
// 1. rdb checkpoint is created at most twice, for publish and checkpoint
1235+
// separately, if they happen on the same train iteration. We can not create
1236+
// rdb checkpoint freely because the lifecycle of rdb checkpoint is controlled
1237+
// on the component side
1238+
//
1239+
// 2. publish tends to call state_dict() multiple times to get model FQNs, and
1240+
// it is not recommended to modify state_dict signature, thus there is no way
1241+
// for the TBE backend to tell which state_dict calls is for weight accessing.
1242+
// state_dict() returns KVTensorWrapper to the trainer side, which will be
1243+
// consumed by the downstream componenet, e.g. checkpoint and publish, we want
1244+
// to link the rdb checkpoint with KVTensorWrapper
1245+
//
1246+
// 3. therefore we need to way to linked the created rdb checkpoint with
1247+
// KVTensorWrapper, and potentially we could have multiple rdb
1248+
// checkpoint from different iteration(this is less likely, especially if we
1249+
// don't copy KVTensorWrapper to a separate python thread which extends the
1250+
// rdb checkpoint handler lifetime). But just in case, we created a global
1251+
// step -> rdb checkpoint mapping
1252+
std::unordered_map<int64_t, std::string> global_step_to_ckpt_uuid_;
11521253
}; // class EmbeddingRocksDB
11531254

11541255
} // namespace ssd

0 commit comments

Comments
 (0)