Skip to content

Adding helper function for enabling RocksDB Checkpoint #4213

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,20 @@

#include <ATen/ATen.h>
#include <folly/hash/Hash.h>
#include <glog/logging.h>
#include <stddef.h>
#include <stdint.h>
#include <filesystem>
#include <optional>

/// @defgroup embedding-ssd Embedding SSD Operators
///

namespace kv_db_utils {

#ifdef FBGEMM_FBCODE
constexpr size_t num_ssd_drives = 8;
#endif

/// @ingroup embedding-ssd
///
/// @brief hash function used for SSD L2 cache and rocksdb sharding algorithm
Expand Down Expand Up @@ -65,4 +70,94 @@ std::tuple<at::Tensor, at::Tensor> get_bucket_sorted_indices_and_bucket_tensor(
std::optional<int64_t> bucket_size,
std::optional<int64_t> total_num_buckets);

/// @ingroup embedding-ssd
///
/// @brief default way to generate rocksdb path based on a user provided
/// base_path the file hierarchy will be
/// <base_path><ssd_idx>/<tbe_uuid> for default SSD mount
/// <base_path>/<tbe_uuid> for user provided base path
///
/// @param base_path the base path for all the rocksdb shards tied to one
/// TBE/EmbeddingRocksDB
/// @param db_shard_id the rocksdb shard index, this is used to determine which
/// SSD to use
/// @param tbe_uuid unique identifier per TBE at the lifetime of a training job
/// @param default_path whether the base_path is default SSD mount or
/// user-provided
///
/// @return the base path to that rocksdb shard
inline std::string get_rocksdb_path(
const std::string& base_path,
int db_shard_id,
const std::string& tbe_uuid,
bool default_path) {
if (default_path) {
int ssd_drive_idx = db_shard_id % num_ssd_drives;
std::string ssd_idx_tbe_id_str =
std::to_string(ssd_drive_idx) + std::string("/") + tbe_uuid;
return base_path + ssd_idx_tbe_id_str;
} else {
return base_path + std::string("/") + tbe_uuid;
}
}

/// @ingroup embedding-ssd
///
/// @brief generate rocksdb shard path, based on rocksdb_path
/// the file hierarchy will be
/// <rocksdb_shard_path>/shard_<db_shard>
///
/// @param db_shard_id the rocksdb shard index
/// @param rocksdb_path the base path for rocksdb shard
///
/// @return the rocksdb shard path
inline std::string get_rocksdb_shard_path(
int db_shard_id,
const std::string& rocksdb_path) {
return rocksdb_path + std::string("/shard_") + std::to_string(db_shard_id);
}

/// @ingroup embedding-ssd
///
/// @brief generate a directory to hold rocksdb checkpoint for a particular
/// rocksdb shard path the file hierarchy will be
/// <rocksdb_shard_path>/checkpoint_shard_<db_shard>
///
/// @param db_shard_id the rocksdb shard index
/// @param rocksdb_path the base path for rocksdb shard
///
/// @return the directory that holds rocksdb checkpoints for one rocksdb shard
inline std::string get_rocksdb_checkpoint_dir(
int db_shard_id,
const std::string& rocksdb_path) {
return rocksdb_path + std::string("/checkpoint_shard_") +
std::to_string(db_shard_id);
}

inline void create_dir(const std::string& dir_path) {
try {
std::filesystem::path fs_path(dir_path);
bool res = std::filesystem::create_directories(fs_path);
if (!res) {
LOG(ERROR) << "dir: " << dir_path << " already exists";
}
} catch (const std::exception& e) {
LOG(ERROR) << "Error creating directory: " << e.what();
}
}

inline void remove_dir(const std::string& path) {
if (std::filesystem::exists(path)) {
try {
if (std::filesystem::is_directory(path)) {
std::filesystem::remove_all(path);
} else {
std::filesystem::remove(path);
}
} catch (const std::filesystem::filesystem_error& e) {
LOG(ERROR) << "Error removing path: " << path
<< ", exception:" << e.what();
}
}
}
}; // namespace kv_db_utils
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,7 @@ namespace ssd {

using namespace at;

#ifdef FBGEMM_FBCODE
constexpr size_t num_ssd_drives = 8;
const std::string ssd_mount_point = "/data00_nvidia";
const size_t base_port = 136000;
#endif

// mem usage propertiese
// -- block cache usage
Expand Down Expand Up @@ -322,23 +318,22 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
auto db_monitor_options = facebook::fb_rocksdb::DBMonitorOptions();
db_monitor_options.fb303Prefix = "tbe_metrics";

std::string tbe_uuid = "";
tbe_uuid_ = facebook::strings::generateUUID();
use_default_ssd_path_ = !use_passed_in_path;
if (!use_passed_in_path) {
path = ssd_mount_point;
tbe_uuid = facebook::strings::generateUUID();
path_ = std::move(ssd_mount_point);
} else {
path_ = std::move(path);
}
std::string all_shards_path;
#endif
for (auto i = 0; i < num_shards; ++i) {
#ifdef FBGEMM_FBCODE
int ssd_drive_idx = i % num_ssd_drives;
std::string ssd_idx_tbe_id_str = "";
if (!use_passed_in_path) {
ssd_idx_tbe_id_str =
std::to_string(ssd_drive_idx) + std::string("/") + tbe_uuid;
}
auto shard_path =
path + ssd_idx_tbe_id_str + std::string("_shard") + std::to_string(i);
used_path += shard_path + ", ";
auto rocksdb_path = kv_db_utils::get_rocksdb_path(
path_, i, tbe_uuid_, !use_passed_in_path);
auto shard_path = kv_db_utils::get_rocksdb_shard_path(i, rocksdb_path);
kv_db_utils::create_dir(shard_path);
all_shards_path += shard_path + ", ";
#else
auto shard_path = path + std::string("/shard_") + std::to_string(i);
#endif
Expand Down Expand Up @@ -371,7 +366,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
dbs_.emplace_back(db);
}
#ifdef FBGEMM_FBCODE
LOG(INFO) << "TBE actual used_path: " << used_path;
LOG(INFO) << "TBE uuid: " << tbe_uuid_
<< ", rocksdb shards paths: " << all_shards_path;
#endif
}

Expand Down Expand Up @@ -1152,6 +1148,9 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
int64_t elem_size_;
std::vector<int64_t> sub_table_dims_;
std::vector<int64_t> sub_table_hash_cumsum_;
std::string tbe_uuid_;
std::string path_;
bool use_default_ssd_path_;
}; // class EmbeddingRocksDB

} // namespace ssd
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ TEST(RocksDbEmbeddingCacheTest, TestPutAndGet) {
-0.01, // uniform_init_lower,
0.01, // uniform_init_upper,
32, // row_storage_bitwidth = 32,
0 // cache_size = 0
0, // cache_size = 0
true // use_passed_in_path
);

auto write_indices =
Expand Down
Loading