|
10 | 10 |
|
11 | 11 | #include <ATen/ATen.h>
|
12 | 12 | #include <folly/hash/Hash.h>
|
| 13 | +#include <glog/logging.h> |
13 | 14 | #include <stddef.h>
|
14 | 15 | #include <stdint.h>
|
| 16 | +#include <filesystem> |
15 | 17 | #include <optional>
|
16 |
| - |
17 | 18 | /// @defgroup embedding-ssd Embedding SSD Operators
|
18 | 19 | ///
|
19 | 20 |
|
20 | 21 | namespace kv_db_utils {
|
21 | 22 |
|
| 23 | +#ifdef FBGEMM_FBCODE |
| 24 | +constexpr size_t num_ssd_drives = 8; |
| 25 | +#endif |
| 26 | + |
22 | 27 | /// @ingroup embedding-ssd
|
23 | 28 | ///
|
24 | 29 | /// @brief hash function used for SSD L2 cache and rocksdb sharding algorithm
|
@@ -65,4 +70,94 @@ std::tuple<at::Tensor, at::Tensor> get_bucket_sorted_indices_and_bucket_tensor(
|
65 | 70 | std::optional<int64_t> bucket_size,
|
66 | 71 | std::optional<int64_t> total_num_buckets);
|
67 | 72 |
|
| 73 | +/// @ingroup embedding-ssd |
| 74 | +/// |
| 75 | +/// @brief default way to generate rocksdb path based on a user provided |
| 76 | +/// base_path the file hierarchy will be |
| 77 | +/// <base_path><ssd_idx>/<tbe_uuid> for default SSD mount |
| 78 | +/// <base_path>/<tbe_uuid> for user provided base path |
| 79 | +/// |
| 80 | +/// @param base_path the base path for all the rocksdb shards tied to one |
| 81 | +/// TBE/EmbeddingRocksDB |
| 82 | +/// @param db_shard_id the rocksdb shard index, this is used to determine which |
| 83 | +/// SSD to use |
| 84 | +/// @param tbe_uuid unique identifier per TBE at the lifetime of a training job |
| 85 | +/// @param default_path whether the base_path is default SSD mount or |
| 86 | +/// user-provided |
| 87 | +/// |
| 88 | +/// @return the base path to that rocksdb shard |
| 89 | +inline std::string get_rocksdb_path( |
| 90 | + const std::string& base_path, |
| 91 | + int db_shard_id, |
| 92 | + const std::string& tbe_uuid, |
| 93 | + bool default_path) { |
| 94 | + if (default_path) { |
| 95 | + int ssd_drive_idx = db_shard_id % num_ssd_drives; |
| 96 | + std::string ssd_idx_tbe_id_str = |
| 97 | + std::to_string(ssd_drive_idx) + std::string("/") + tbe_uuid; |
| 98 | + return base_path + ssd_idx_tbe_id_str; |
| 99 | + } else { |
| 100 | + return base_path + std::string("/") + tbe_uuid; |
| 101 | + } |
| 102 | +} |
| 103 | + |
| 104 | +/// @ingroup embedding-ssd |
| 105 | +/// |
| 106 | +/// @brief generate rocksdb shard path, based on rocksdb_path |
| 107 | +/// the file hierarchy will be |
| 108 | +/// <rocksdb_shard_path>/shard_<db_shard> |
| 109 | +/// |
| 110 | +/// @param db_shard_id the rocksdb shard index |
| 111 | +/// @param rocksdb_path the base path for rocksdb shard |
| 112 | +/// |
| 113 | +/// @return the rocksdb shard path |
| 114 | +inline std::string get_rocksdb_shard_path( |
| 115 | + int db_shard_id, |
| 116 | + const std::string& rocksdb_path) { |
| 117 | + return rocksdb_path + std::string("/shard_") + std::to_string(db_shard_id); |
| 118 | +} |
| 119 | + |
| 120 | +/// @ingroup embedding-ssd |
| 121 | +/// |
| 122 | +/// @brief generate a directory to hold rocksdb checkpoint for a particular |
| 123 | +/// rocksdb shard path the file hierarchy will be |
| 124 | +/// <rocksdb_shard_path>/checkpoint_shard_<db_shard> |
| 125 | +/// |
| 126 | +/// @param db_shard_id the rocksdb shard index |
| 127 | +/// @param rocksdb_path the base path for rocksdb shard |
| 128 | +/// |
| 129 | +/// @return the directory that holds rocksdb checkpoints for one rocksdb shard |
| 130 | +inline std::string get_rocksdb_checkpoint_dir( |
| 131 | + int db_shard_id, |
| 132 | + const std::string& rocksdb_path) { |
| 133 | + return rocksdb_path + std::string("/checkpoint_shard_") + |
| 134 | + std::to_string(db_shard_id); |
| 135 | +} |
| 136 | + |
| 137 | +inline void create_dir(const std::string& dir_path) { |
| 138 | + try { |
| 139 | + std::filesystem::path fs_path(dir_path); |
| 140 | + bool res = std::filesystem::create_directories(fs_path); |
| 141 | + if (!res) { |
| 142 | + LOG(ERROR) << "dir: " << dir_path << " already exists"; |
| 143 | + } |
| 144 | + } catch (const std::exception& e) { |
| 145 | + LOG(ERROR) << "Error creating directory: " << e.what(); |
| 146 | + } |
| 147 | +} |
| 148 | + |
| 149 | +inline void remove_dir(const std::string& path) { |
| 150 | + if (std::filesystem::exists(path)) { |
| 151 | + try { |
| 152 | + if (std::filesystem::is_directory(path)) { |
| 153 | + std::filesystem::remove_all(path); |
| 154 | + } else { |
| 155 | + std::filesystem::remove(path); |
| 156 | + } |
| 157 | + } catch (const std::filesystem::filesystem_error& e) { |
| 158 | + LOG(ERROR) << "Error removing path: " << path |
| 159 | + << ", exception:" << e.what(); |
| 160 | + } |
| 161 | + } |
| 162 | +} |
68 | 163 | }; // namespace kv_db_utils
|
0 commit comments