Skip to content

Commit 792e68a

Browse files
duduyi2013facebook-github-bot
authored andcommitted
chunk processing l2 cache flush (#4216)
Summary: Pull Request resolved: #4216 X-link: facebookresearch/FBGEMM#1292 as title, when we have a large L2 cache, flush will double up the mem footprint Reviewed By: emlin Differential Revision: D75314575 fbshipit-source-id: c0963665aed9065d833bc94e961a528d239b1ada
1 parent d4061ad commit 792e68a

File tree

9 files changed

+87
-51
lines changed

9 files changed

+87
-51
lines changed

fbgemm_gpu/fbgemm_gpu/tbe/ssd/training.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ def __init__(
168168
kv_zch_params: Optional[KVZCHParams] = None,
169169
enable_raw_embedding_streaming: bool = False, # whether enable raw embedding streaming
170170
res_params: Optional[RESParams] = None, # raw embedding streaming sharding info
171+
flushing_block_size: int = 2_000_000_000, # 2GB
171172
) -> None:
172173
super(SSDTableBatchedEmbeddingBags, self).__init__()
173174

@@ -520,15 +521,19 @@ def __init__(
520521
logging.info(f"tbe_unique_id: {tbe_unique_id}")
521522
if self.backend_type == BackendType.SSD:
522523
logging.info(
523-
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, enable_async_update:{enable_async_update}"
524-
f"passed_in_path={ssd_directory}, num_shards={ssd_rocksdb_shards},num_threads={ssd_rocksdb_shards},"
525-
f"memtable_flush_period={ssd_memtable_flush_period},memtable_flush_offset={ssd_memtable_flush_offset},"
526-
f"l0_files_per_compact={ssd_l0_files_per_compact},max_D={self.max_D},cache_row_dim={self.cache_row_dim},rate_limit_mbps={ssd_rate_limit_mbps},"
527-
f"size_ratio={ssd_size_ratio},compaction_trigger={ssd_compaction_trigger}, lazy_bulk_init_enabled={lazy_bulk_init_enabled},"
528-
f"write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size},max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num},"
529-
f"uniform_init_lower={ssd_uniform_init_lower},uniform_init_upper={ssd_uniform_init_upper},"
530-
f"row_storage_bitwidth={weights_precision.bit_rate()},block_cache_size_per_tbe={ssd_block_cache_size_per_tbe},"
531-
f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB, enable_raw_embedding_streaming:{self.enable_raw_embedding_streaming}"
524+
f"Logging SSD offloading setup, tbe_unique_id:{tbe_unique_id}, l2_cache_size:{l2_cache_size}GB, "
525+
f"enable_async_update:{enable_async_update}, passed_in_path={ssd_directory}, "
526+
f"num_shards={ssd_rocksdb_shards}, num_threads={ssd_rocksdb_shards}, "
527+
f"memtable_flush_period={ssd_memtable_flush_period}, memtable_flush_offset={ssd_memtable_flush_offset}, "
528+
f"l0_files_per_compact={ssd_l0_files_per_compact}, max_D={self.max_D}, "
529+
f"cache_row_size={self.cache_row_dim}, rate_limit_mbps={ssd_rate_limit_mbps}, "
530+
f"size_ratio={ssd_size_ratio}, compaction_trigger={ssd_compaction_trigger}, "
531+
f"lazy_bulk_init_enabled={lazy_bulk_init_enabled}, write_buffer_size_per_tbe={ssd_rocksdb_write_buffer_size}, "
532+
f"max_write_buffer_num_per_db_shard={ssd_max_write_buffer_num}, "
533+
f"uniform_init_lower={ssd_uniform_init_lower}, uniform_init_upper={ssd_uniform_init_upper}, "
534+
f"row_storage_bitwidth={weights_precision.bit_rate()}, block_cache_size_per_tbe={ssd_block_cache_size_per_tbe}, "
535+
f"use_passed_in_path:{use_passed_in_path}, real_path will be printed in EmbeddingRocksDB, "
536+
f"enable_raw_embedding_streaming:{self.enable_raw_embedding_streaming}, flushing_block_size:{flushing_block_size}"
532537
)
533538
# pyre-fixme[4]: Attribute must be annotated.
534539
self._ssd_db = torch.classes.fbgemm.EmbeddingRocksDBWrapper(
@@ -568,6 +573,7 @@ def __init__(
568573
if self.enable_optimizer_offloading
569574
else None
570575
),
576+
flushing_block_size,
571577
)
572578
if self.bulk_init_chunk_size > 0:
573579
self.ssd_uniform_init_lower: float = ssd_uniform_init_lower

fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/cachelib_cache.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ class CacheLibCache {
4343
const CacheConfig& cache_config,
4444
int64_t unique_tbe_id);
4545

46+
size_t get_cache_item_size() const;
47+
Cache::AccessIterator begin();
4648
std::unique_ptr<Cache> initializeCacheLib(const CacheConfig& config);
4749

4850
std::unique_ptr<facebook::cachelib::CacheAdmin> createCacheAdmin(
@@ -99,7 +101,7 @@ class CacheLibCache {
99101
/// @note cache_->allocation will trigger eviction callback func
100102
bool put(const at::Tensor& key_tensor, const at::Tensor& data);
101103

102-
/// iterate through all items in L2 cache, fill them in indices and weights
104+
/// iterate through N items in L2 cache, fill them in indices and weights
103105
/// respectively and return indices, weights and count
104106
///
105107
/// @return optional value, if cache is empty return none
@@ -109,11 +111,11 @@ class CacheLibCache {
109111
/// relative element in <indices>
110112
/// @return count A single element tensor that contains the number of indices
111113
/// to be processed
112-
///
113114
/// @note this isn't thread safe, caller needs to make sure put isn't called
114115
/// while this is executed.
115-
folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
116-
get_all_items();
116+
folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>> get_n_items(
117+
int n,
118+
Cache::AccessIterator& start_itr);
117119

118120
/// instantiate eviction related indices and weights tensors(size of <count>)
119121
/// for L2 eviction using the same dtype and device from <indices> and

fbgemm_gpu/src/split_embeddings_cache/cachelib_cache.cpp

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,14 @@ CacheLibCache::CacheLibCache(
3434
}
3535
}
3636

37+
size_t CacheLibCache::get_cache_item_size() const {
38+
return cache_config_.item_size_bytes;
39+
}
40+
41+
Cache::AccessIterator CacheLibCache::begin() {
42+
return cache_->begin();
43+
}
44+
3745
std::unique_ptr<Cache> CacheLibCache::initializeCacheLib(
3846
const CacheConfig& config) {
3947
auto eviction_cb = [this](
@@ -177,50 +185,41 @@ bool CacheLibCache::put(const at::Tensor& key_tensor, const at::Tensor& data) {
177185
}
178186

179187
folly::Optional<std::tuple<at::Tensor, at::Tensor, at::Tensor>>
180-
CacheLibCache::get_all_items() {
188+
CacheLibCache::get_n_items(int n, Cache::AccessIterator& itr) {
181189
if (!index_dtype_.has_value() || !weights_dtype_.has_value()) {
182190
return folly::none;
183191
}
184-
int total_num_items = 0;
185-
for (auto& pool_id : pool_ids_) {
186-
total_num_items += cache_->getPoolStats(pool_id).numItems();
187-
}
188192
auto weight_dim = cache_config_.max_D_;
189193
auto indices = at::empty(
190-
total_num_items,
191-
at::TensorOptions().dtype(index_dtype_.value()).device(at::kCPU));
194+
n, at::TensorOptions().dtype(index_dtype_.value()).device(at::kCPU));
192195
auto weights = at::empty(
193-
{total_num_items, weight_dim},
196+
{n, weight_dim},
194197
at::TensorOptions().dtype(weights_dtype_.value()).device(at::kCPU));
198+
int cnt = 0;
195199
FBGEMM_DISPATCH_FLOAT_HALF_AND_BYTE(
196-
weights.scalar_type(), "get_all_items", [&] {
200+
weights.scalar_type(), "get_n_items", [&] {
197201
using value_t = scalar_t;
198202
FBGEMM_DISPATCH_INTEGRAL_TYPES(
199-
indices.scalar_type(), "get_all_items", [&] {
203+
indices.scalar_type(), "get_n_items", [&] {
200204
using index_t = scalar_t;
201205
auto indices_data_ptr = indices.data_ptr<index_t>();
202206
auto weights_data_ptr = weights.data_ptr<value_t>();
203-
int64_t item_idx = 0;
204-
for (auto itr = cache_->begin(); itr != cache_->end(); ++itr) {
207+
for (; itr != cache_->end() && cnt < n; ++itr, ++cnt) {
205208
const auto key_ptr =
206209
reinterpret_cast<const index_t*>(itr->getKey().data());
207-
indices_data_ptr[item_idx] = *key_ptr;
210+
indices_data_ptr[cnt] = *key_ptr;
208211
std::copy(
209212
reinterpret_cast<const value_t*>(itr->getMemory()),
210213
reinterpret_cast<const value_t*>(itr->getMemory()) +
211214
weight_dim,
212-
&weights_data_ptr[item_idx * weight_dim]); // dst_start
213-
item_idx++;
215+
&weights_data_ptr[cnt * weight_dim]); // dst_start
214216
}
215-
CHECK_EQ(total_num_items, item_idx);
216217
});
217218
});
218219
return std::make_tuple(
219220
indices,
220221
weights,
221-
at::tensor(
222-
{total_num_items},
223-
at::TensorOptions().dtype(at::kLong).device(at::kCPU)));
222+
at::tensor({cnt}, at::TensorOptions().dtype(at::kLong).device(at::kCPU)));
224223
}
225224

226225
void CacheLibCache::init_tensor_for_l2_eviction(

fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
4343
std::vector<int64_t> table_offsets = {},
4444
const std::vector<int64_t>& table_sizes = {},
4545
std::optional<at::Tensor> table_dims = std::nullopt,
46-
std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
46+
std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
47+
int64_t flushing_block_size = 2000000000 /*2GB*/)
4748
: impl_(std::make_shared<ssd::EmbeddingRocksDB>(
4849
path,
4950
num_shards,
@@ -72,7 +73,8 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
7273
std::move(table_offsets),
7374
table_sizes,
7475
table_dims,
75-
hash_size_cumsum)) {}
76+
hash_size_cumsum,
77+
flushing_block_size)) {}
7678

7779
void set_cuda(
7880
at::Tensor indices,

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.cpp

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ EmbeddingKVDB::EmbeddingKVDB(
108108
int64_t res_server_port,
109109
std::vector<std::string> table_names,
110110
std::vector<int64_t> table_offsets,
111-
const std::vector<int64_t>& table_sizes)
112-
: unique_id_(unique_id),
111+
const std::vector<int64_t>& table_sizes,
112+
int64_t flushing_block_size)
113+
: flushing_block_size_(flushing_block_size),
114+
unique_id_(unique_id),
113115
num_shards_(num_shards),
114116
max_D_(max_D),
115117
executor_tp_(std::make_unique<folly::CPUThreadPoolExecutor>(num_shards)),
@@ -333,17 +335,29 @@ void EmbeddingKVDB::update_cache_and_storage(
333335
void EmbeddingKVDB::flush() {
334336
wait_util_filling_work_done();
335337
if (l2_cache_) {
336-
auto tensor_tuple_opt = l2_cache_->get_all_items();
337-
if (!tensor_tuple_opt.has_value()) {
338-
XLOG(INFO) << "[TBE_ID" << unique_id_
339-
<< "]no items exist in L2 cache, flush nothing";
340-
return;
338+
int block_size = std::max(
339+
(int)(flushing_block_size_ / l2_cache_->get_cache_item_size()), 1);
340+
folly::Optional<l2_cache::CacheLibCache::Cache::AccessIterator> start_itr =
341+
folly::none;
342+
folly::Optional<at::Tensor> count = folly::none;
343+
auto itr = l2_cache_->begin();
344+
while (count == folly::none || count->item<int64_t>() > 0) {
345+
auto res_tuple_opt = l2_cache_->get_n_items(block_size, itr);
346+
if (!res_tuple_opt.has_value()) {
347+
XLOG(INFO) << "[TBE_ID" << unique_id_
348+
<< "]no items exist in L2 cache, flush nothing";
349+
return;
350+
}
351+
auto& indices = std::get<0>(res_tuple_opt.value());
352+
auto& weights = std::get<1>(res_tuple_opt.value());
353+
count = std::get<2>(res_tuple_opt.value());
354+
355+
if (count->item<int64_t>() > 0) {
356+
set_kv_db_async(
357+
indices, weights, count.value(), kv_db::RocksdbWriteMode::FLUSH)
358+
.wait();
359+
}
341360
}
342-
auto& indices = std::get<0>(tensor_tuple_opt.value());
343-
auto& weights = std::get<1>(tensor_tuple_opt.value());
344-
auto& count = std::get<2>(tensor_tuple_opt.value());
345-
set_kv_db_async(indices, weights, count, kv_db::RocksdbWriteMode::FLUSH)
346-
.wait();
347361
}
348362
}
349363

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_db_table_batched_embeddings.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
142142
int64_t res_server_port = 0,
143143
std::vector<std::string> table_names = {},
144144
std::vector<int64_t> table_offsets = {},
145-
const std::vector<int64_t>& table_sizes = {});
145+
const std::vector<int64_t>& table_sizes = {},
146+
int64_t flushing_block_size = 2000000000 /*2GB*/);
146147

147148
virtual ~EmbeddingKVDB();
148149

@@ -396,6 +397,8 @@ class EmbeddingKVDB : public std::enable_shared_from_this<EmbeddingKVDB> {
396397
const at::Tensor& weights);
397398

398399
std::unique_ptr<l2_cache::CacheLibCache> l2_cache_;
400+
// when flushing l2, the block size in bytes that we flush l2 progressively
401+
int64_t flushing_block_size_;
399402
const int64_t unique_id_;
400403
const int64_t num_shards_;
401404
const int64_t max_D_;

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_split_table_batched_embeddings.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,7 @@ void KVTensorWrapper::set_range(
379379
const at::Tensor& weights) {
380380
// Mutex lock for disabling concurrent writes to the same KVTensor
381381
std::lock_guard<std::mutex> lock(mtx);
382+
CHECK_EQ(weights.device(), at::kCPU);
382383
CHECK_EQ(dim, 0) << "Only set_range on dim 0 is supported";
383384
CHECK_TRUE(db_ != nullptr);
384385
CHECK_GE(db_->get_max_D(), shape_[1]);
@@ -396,6 +397,7 @@ void KVTensorWrapper::set_range(
396397
void KVTensorWrapper::set_weights_and_ids(
397398
const at::Tensor& weights,
398399
const at::Tensor& ids) {
400+
CHECK_EQ(weights.device(), at::kCPU);
399401
CHECK_TRUE(db_ != nullptr);
400402
CHECK_EQ(ids.size(0), weights.size(0))
401403
<< "ids and weights must have same # rows";
@@ -502,7 +504,8 @@ static auto embedding_rocks_db_wrapper =
502504
std::vector<int64_t>,
503505
std::vector<int64_t>,
504506
std::optional<at::Tensor>,
505-
std::optional<at::Tensor>>(),
507+
std::optional<at::Tensor>,
508+
int64_t>(),
506509
"",
507510
{
508511
torch::arg("path"),
@@ -533,6 +536,7 @@ static auto embedding_rocks_db_wrapper =
533536
torch::arg("table_sizes") = torch::List<int64_t>(),
534537
torch::arg("table_dims") = std::nullopt,
535538
torch::arg("hash_size_cumsum") = std::nullopt,
539+
torch::arg("flushing_block_size") = 2000000000 /* 2GB */,
536540
})
537541
.def(
538542
"set_cuda",

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
103103
std::vector<int64_t> table_offsets = {},
104104
const std::vector<int64_t>& table_sizes = {},
105105
std::optional<at::Tensor> table_dims = std::nullopt,
106-
std::optional<at::Tensor> hash_size_cumsum = std::nullopt)
106+
std::optional<at::Tensor> hash_size_cumsum = std::nullopt,
107+
int64_t flushing_block_size = 2000000000 /*2GB*/)
107108
: kv_db::EmbeddingKVDB(
108109
num_shards,
109110
max_D,
@@ -116,7 +117,8 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
116117
res_server_port,
117118
std::move(table_names),
118119
std::move(table_offsets),
119-
table_sizes),
120+
table_sizes,
121+
flushing_block_size),
120122
auto_compaction_enabled_(true),
121123
max_D_(max_D),
122124
elem_size_(row_storage_bitwidth / 8) {

fbgemm_gpu/test/tbe/ssd/kv_backend_test.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def generate_fbgemm_kv_tbe(
6161
ssd_rocksdb_shards: int = 1,
6262
kv_zch_params: Optional[KVZCHParams] = None,
6363
backend_type: BackendType = BackendType.SSD,
64+
flushing_block_size: int = 1000,
6465
) -> Tuple[SSDTableBatchedEmbeddingBags, List[int], List[int]]:
6566
E = int(10**log_E)
6667
D = D * 4
@@ -89,6 +90,7 @@ def generate_fbgemm_kv_tbe(
8990
ssd_rocksdb_shards=ssd_rocksdb_shards,
9091
kv_zch_params=kv_zch_params,
9192
backend_type=backend_type,
93+
flushing_block_size=flushing_block_size,
9294
)
9395
return emb, Es, Ds
9496

@@ -103,7 +105,9 @@ def test_l2_flush(
103105
weights_precision: SparseType,
104106
do_flush: bool,
105107
) -> None:
106-
emb, Es, _ = self.generate_fbgemm_kv_tbe(T, D, log_E, weights_precision, mixed)
108+
emb, Es, _ = self.generate_fbgemm_kv_tbe(
109+
T, D, log_E, weights_precision, mixed, flushing_block_size=1
110+
)
107111
indices = torch.arange(start=0, end=sum(Es))
108112
weights = torch.randn(
109113
indices.numel(), emb.cache_row_dim, dtype=weights_precision.as_dtype()

0 commit comments

Comments
 (0)