@@ -65,6 +65,25 @@ class SnapshotHandle {
65
65
std::vector<snapshot_ptr_t > shard_snapshots_;
66
66
}; // class SnapshotHandle
67
67
68
+ using checkpoint_path = std::string;
69
+ // @lint-ignore CLANGTIDY cppcoreguidelines-special-member-functions
70
+ class CheckpointHandle {
71
+ public:
72
+ explicit CheckpointHandle (
73
+ EmbeddingRocksDB* db,
74
+ const std::string& tbe_uuid,
75
+ const std::string& ckpt_uuid,
76
+ const std::string& base_path,
77
+ bool use_default_ssd_path);
78
+
79
+ private:
80
+ friend class EmbeddingRocksDB ;
81
+
82
+ EmbeddingRocksDB* db_;
83
+ std::string ckpt_uuid_;
84
+ std::vector<checkpoint_path> shard_checkpoints_;
85
+ }; // class CheckpointHandle
86
+
68
87
// / @ingroup embedding-ssd
69
88
// /
70
89
// / @brief An implementation of EmbeddingKVDB for RocksDB
@@ -488,6 +507,10 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
488
507
return snapshots_.find (snapshot_handle) != snapshots_.end ();
489
508
}
490
509
510
+ bool is_valid_checkpoint (const std::string ckpt_uuid) const {
511
+ return checkpoints_.find (ckpt_uuid) != checkpoints_.end ();
512
+ }
513
+
491
514
int64_t get_snapshot_count () const {
492
515
return snapshots_.size ();
493
516
}
@@ -505,6 +528,57 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
505
528
return handlePtr;
506
529
}
507
530
531
+ std::string create_checkpoint (int64_t global_step) {
532
+ const auto num_ckpts = checkpoints_.size ();
533
+ if (num_ckpts > 0 ) {
534
+ std::cerr << " WARNING: rocksdb create_checkpoint found " << num_ckpts
535
+ << " other checkpoints_" << std::endl;
536
+ }
537
+
538
+ // If the global step already has a checkpoint handler registered, at the
539
+ // time create_checkpoint is call, we assume the prev ckpt hanlder has
540
+ // fullfilled its job already, thus it is ok to replace it with the new rdb
541
+ // checkpoint for next use cases within the same global step
542
+ if (global_step_to_ckpt_uuid_.find (global_step) !=
543
+ global_step_to_ckpt_uuid_.end ()) {
544
+ LOG (WARNING)
545
+ << " multiple rdb checkpoint in one global step are being created, "
546
+ " removing the prev rdb ckpt, please make sure it has fullfilled "
547
+ " its use case, e.g. checkpoint and publish" ;
548
+ }
549
+ auto ckpt_uuid = facebook::strings::generateUUID ();
550
+ auto handle = std::make_unique<CheckpointHandle>(
551
+ this , tbe_uuid_, ckpt_uuid, path_, use_default_ssd_path_);
552
+ checkpoints_[ckpt_uuid] = std::move (handle);
553
+ global_step_to_ckpt_uuid_[global_step] = ckpt_uuid;
554
+ return ckpt_uuid;
555
+ }
556
+
557
+ std::optional<std::string> get_active_checkpoint_uuid (int64_t global_step) {
558
+ if (global_step_to_ckpt_uuid_.find (global_step) !=
559
+ global_step_to_ckpt_uuid_.end ()) {
560
+ return std::make_optional<std::string>(
561
+ global_step_to_ckpt_uuid_[global_step]);
562
+ }
563
+ return std::nullopt;
564
+ }
565
+
566
+ void release_checkpoint (const std::string ckpt_uuid) {
567
+ CHECK_EQ (is_valid_checkpoint (ckpt_uuid), true );
568
+ LOG (INFO) << " Checkpoint " << ckpt_uuid << " released" ;
569
+ checkpoints_.erase (ckpt_uuid);
570
+ // sweep through global_step_to_ckpt_uuid_, it should be small
571
+ int64_t glb_step_to_purge = -1 ;
572
+ for (const auto & [global_step, uuid] : global_step_to_ckpt_uuid_) {
573
+ if (ckpt_uuid == uuid) {
574
+ glb_step_to_purge = global_step;
575
+ break ;
576
+ }
577
+ }
578
+ CHECK_NE (glb_step_to_purge, -1 ) << " There must be a rdb ckpt uuid to purge" ;
579
+ global_step_to_ckpt_uuid_.erase (glb_step_to_purge);
580
+ }
581
+
508
582
void release_snapshot (const SnapshotHandle* snapshot_handle) {
509
583
CHECK (is_valid_snapshot (snapshot_handle));
510
584
LOG (INFO) << " Snapshot " << snapshot_handle << " released" ;
@@ -1117,6 +1191,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
1117
1191
}
1118
1192
1119
1193
friend class SnapshotHandle ;
1194
+ friend class CheckpointHandle ;
1120
1195
1121
1196
std::vector<std::unique_ptr<rocksdb::DB>> dbs_;
1122
1197
std::vector<std::unique_ptr<Initializer>> initializers_;
@@ -1149,6 +1224,32 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
1149
1224
std::string tbe_uuid_;
1150
1225
std::string path_;
1151
1226
bool use_default_ssd_path_;
1227
+
1228
+ // rocksdb checkpoint is used to create an on disk database to support
1229
+ // cross process read-only access
1230
+ std::unordered_map<std::string, std::unique_ptr<CheckpointHandle>>
1231
+ checkpoints_;
1232
+ // this is used for KVTensor rdb checkpoint linking by global
1233
+ // step, reasons are shown below
1234
+ // 1. rdb checkpoint is created at most twice, for publish and checkpoint
1235
+ // separately, if they happen on the same train iteration. We can not create
1236
+ // rdb checkpoint freely because the lifecycle of rdb checkpoint is controlled
1237
+ // on the component side
1238
+ //
1239
+ // 2. publish tends to call state_dict() multiple times to get model FQNs, and
1240
+ // it is not recommended to modify state_dict signature, thus there is no way
1241
+ // for the TBE backend to tell which state_dict calls is for weight accessing.
1242
+ // state_dict() returns KVTensorWrapper to the trainer side, which will be
1243
+ // consumed by the downstream componenet, e.g. checkpoint and publish, we want
1244
+ // to link the rdb checkpoint with KVTensorWrapper
1245
+ //
1246
+ // 3. therefore we need to way to linked the created rdb checkpoint with
1247
+ // KVTensorWrapper, and potentially we could have multiple rdb
1248
+ // checkpoint from different iteration(this is less likely, especially if we
1249
+ // don't copy KVTensorWrapper to a separate python thread which extends the
1250
+ // rdb checkpoint handler lifetime). But just in case, we created a global
1251
+ // step -> rdb checkpoint mapping
1252
+ std::unordered_map<int64_t , std::string> global_step_to_ckpt_uuid_;
1152
1253
}; // class EmbeddingRocksDB
1153
1254
1154
1255
} // namespace ssd
0 commit comments