Commit 0a963db

chouxi authored and meta-codesync[bot] committed
Thread raw embedding streamer to dram_kv_embedding_cache (#5432)
Summary:
Pull Request resolved: #5432

X-link: https://github.com/facebookresearch/FBGEMM/pull/2404

Thread RES (Raw Embedding Streaming) parameters through the DRAM KV embedding cache constructor chain and pybind to enable streaming for the embedding cache enrichment path. The feature is currently gated by `enable_raw_embedding_streaming`.

Key changes:
- Thread 6 RES params (DramKVEmbeddingCache -> wrapper -> pybind -> Python)
- Make raw_embedding_streamer_ protected for subclass access

Reviewed By: FriedCosey

Differential Revision: D94431329

fbshipit-source-id: 7af10855718fe24b77de1d66d1437d681e47bd48
1 parent 1d3c78e commit 0a963db
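
Before the per-file diffs, here is a minimal, self-contained sketch of the threading pattern the commit applies, using hypothetical MiniKVDB / MiniDramCache classes rather than the actual FBGEMM types: each new RES parameter carries a "feature off" default so pre-existing call sites keep compiling, and the vector parameters are taken by value and moved at each hop of the chain.

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

// Base class standing in for kv_db::EmbeddingKVDB (hypothetical).
class MiniKVDB {
 public:
  MiniKVDB(
      bool enable_raw_embedding_streaming,
      int64_t res_store_shards,
      int64_t res_server_port,
      std::vector<std::string> table_names,
      std::vector<int64_t> table_offsets,
      std::vector<int64_t> table_sizes)
      : enable_raw_embedding_streaming_(enable_raw_embedding_streaming),
        res_store_shards_(res_store_shards),
        res_server_port_(res_server_port),
        table_names_(std::move(table_names)),
        table_offsets_(std::move(table_offsets)),
        table_sizes_(std::move(table_sizes)) {}

 private:
  bool enable_raw_embedding_streaming_;
  int64_t res_store_shards_;
  int64_t res_server_port_;
  std::vector<std::string> table_names_;
  std::vector<int64_t> table_offsets_;
  std::vector<int64_t> table_sizes_;
};

// Derived cache standing in for DramKVEmbeddingCache (hypothetical).
class MiniDramCache : public MiniKVDB {
 public:
  explicit MiniDramCache(
      bool enable_raw_embedding_streaming = false,
      int64_t res_store_shards = 0,
      int64_t res_server_port = 0,
      std::vector<std::string> table_names = {},
      std::vector<int64_t> table_offsets = {},
      std::vector<int64_t> table_sizes = {})
      // Vectors are taken by value and moved along the chain, so threading
      // them costs one move per hop instead of a copy.
      : MiniKVDB(
            enable_raw_embedding_streaming,
            res_store_shards,
            res_server_port,
            std::move(table_names),
            std::move(table_offsets),
            std::move(table_sizes)) {}
};

int main() {
  MiniDramCache streaming_off;  // pre-existing call site: defaults apply
  MiniDramCache streaming_on(true, 4, 1234, {"t1"}, {0}, {100});
  (void)streaming_off;
  (void)streaming_on;
  return 0;
}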

File tree

5 files changed (+56 -6 lines)

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 6 additions & 0 deletions

@@ -783,6 +783,12 @@ def __init__(
                 self.backend_return_whole_row, # backend_return_whole_row
                 False, # enable_async_update
                 self._embedding_cache_mode, # disable_random_init
+                self.enable_raw_embedding_streaming,
+                self.res_params.res_store_shards,
+                self.res_params.res_server_port,
+                self.res_params.table_names,
+                self.res_params.table_offsets,
+                self.res_params.table_sizes,
             )
         else:
             raise AssertionError(f"Invalid backend type {self.backend_type}")

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache.h

Lines changed: 14 additions & 2 deletions

@@ -110,14 +110,26 @@ class DramKVEmbeddingCache : public kv_db::EmbeddingKVDB {
       std::optional<at::Tensor> table_dims = std::nullopt,
       std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
       bool is_training = true,
-      bool disable_random_init = false)
+      bool disable_random_init = false,
+      bool enable_raw_embedding_streaming = false,
+      int64_t res_store_shards = 0,
+      int64_t res_server_port = 0,
+      std::vector<std::string> table_names = {},
+      std::vector<int64_t> table_offsets = {},
+      std::vector<int64_t> table_sizes = {})
       : kv_db::EmbeddingKVDB(
             num_shards,
             max_D,
             0, // l2_cache_size_gb =0 to disable l2 cache
             0, // tbe_unqiue_id
             2, // ele_size_bytes
-            enable_async_update),
+            enable_async_update,
+            enable_raw_embedding_streaming,
+            res_store_shards,
+            res_server_port,
+            std::move(table_names),
+            std::move(table_offsets),
+            table_sizes),
         max_D_(max_D),
         num_shards_(num_shards),
         block_size_(FixedBlockPool::calculate_block_size<weight_type>(max_D)),

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h

Lines changed: 21 additions & 3 deletions

@@ -33,7 +33,13 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       const std::optional<at::Tensor>& hash_size_cumsum = std::nullopt,
       bool backend_return_whole_row = false,
       bool enable_async_update = false,
-      bool disable_random_init = false) {
+      bool disable_random_init = false,
+      bool enable_raw_embedding_streaming = false,
+      int64_t res_store_shards = 0,
+      int64_t res_server_port = 0,
+      std::vector<std::string> table_names = {},
+      std::vector<int64_t> table_offsets = {},
+      std::vector<int64_t> table_sizes = {}) {
     if (row_storage_bitwidth == 16) {
       impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<at::Half>>(
           max_D,
@@ -48,7 +54,13 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           table_dims,
           hash_size_cumsum,
           true, // is_training
-          disable_random_init);
+          disable_random_init,
+          enable_raw_embedding_streaming,
+          res_store_shards,
+          res_server_port,
+          std::move(table_names),
+          std::move(table_offsets),
+          std::move(table_sizes));
     } else if (row_storage_bitwidth == 32) {
       impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<float>>(
           max_D,
@@ -63,7 +75,13 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
           table_dims,
           hash_size_cumsum,
           true, // is_training
-          disable_random_init);
+          disable_random_init,
+          enable_raw_embedding_streaming,
+          res_store_shards,
+          res_server_port,
+          std::move(table_names),
+          std::move(table_offsets),
+          std::move(table_sizes));
     } else {
       throw std::runtime_error("Failed to create recording device");
     }

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h

Lines changed: 2 additions & 0 deletions

@@ -528,6 +528,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {

   // -- commone path
   std::atomic<int64_t> total_cache_update_duration_{0};
+
+ protected:
   std::unique_ptr<fbgemm_gpu::RawEmbeddingStreamer> raw_embedding_streamer_;
 }; // class EmbeddingKVDB
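
The `protected:` move in this hunk is what lets DRAM-side subclasses reach the streamer directly. A tiny sketch of that access pattern, with a stub standing in for fbgemm_gpu::RawEmbeddingStreamer (all names hypothetical):

#include <memory>

// Stub standing in for fbgemm_gpu::RawEmbeddingStreamer.
struct StubStreamer {
  void stream() {}
};

class BaseKVDB {
 public:
  virtual ~BaseKVDB() = default;

 protected:
  // Declared protected (as the diff does for raw_embedding_streamer_) so
  // subclasses can use the streamer without a public accessor.
  std::unique_ptr<StubStreamer> streamer_;
};

class DramCacheSubclass : public BaseKVDB {
 public:
  void flush() {
    if (streamer_) {
      streamer_->stream();  // direct access from the subclass
    }
  }
};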

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 13 additions & 1 deletion

@@ -978,7 +978,13 @@ static auto dram_kv_embedding_cache_wrapper =
             std::optional<at::Tensor>,
             bool,
             bool,
-            bool>(),
+            bool,
+            bool,
+            int64_t,
+            int64_t,
+            std::vector<std::string>,
+            std::vector<int64_t>,
+            std::vector<int64_t>>(),
         "",
         {
             torch::arg("max_D"),
@@ -993,6 +999,12 @@ static auto dram_kv_embedding_cache_wrapper =
             torch::arg("backend_return_whole_row") = false,
             torch::arg("enable_async_update") = false,
             torch::arg("disable_random_init") = false,
+            torch::arg("enable_raw_embedding_streaming") = false,
+            torch::arg("res_store_shards") = 0,
+            torch::arg("res_server_port") = 0,
+            torch::arg("table_names") = std::vector<std::string>{},
+            torch::arg("table_offsets") = std::vector<int64_t>{},
+            torch::arg("table_sizes") = std::vector<int64_t>{},
         })
     .def(
         "set_cuda",
