Commit bb1cbc2

small changes for kvzch (#4073)
Summary:
Pull Request resolved: #4073
X-link: facebookresearch/FBGEMM#1157

Change set:
1. Introduce two new types (KVZCHParams and BackendType) to better integrate KVZCH into SSD TBE.
2. Add a virtual indicator to PartiallyMaterializedTensor so that when checkpoint and publish see a PMT, they can tell whether it backs a normal SSD embedding table or a KV ZCH embedding table.
3. Rename the hash modes to chunk-based and interleave-based.
4. Change the id and bucket tensors from 1D to 2D.

Differential Revision: D74137570
1 parent d17c6d9 commit bb1cbc2

8 files changed: +69, -20 lines

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_common.py

Lines changed: 28 additions & 1 deletion

@@ -11,7 +11,7 @@

 import enum
 from dataclasses import dataclass
-from typing import List, NamedTuple
+from typing import List, NamedTuple, Tuple

 import torch
 from torch import Tensor
@@ -49,6 +49,33 @@ def from_str(cls, key: str):
         raise ValueError(f"Cannot parse value into EmbeddingLocation: {key}")


+class KVZCHParams(NamedTuple):
+    # global bucket id start and global bucket id end offsets for each logical table,
+    # where start offset is inclusive and end offset is exclusive
+    bucket_offsets: List[Tuple[int, int]] = []
+    # bucket size for each logical table; the value indicates the
+    # corresponding input space for each bucket id, e.g. 2^50 / total_num_buckets
+    bucket_sizes: List[int] = []
+
+
+class BackendType(enum.IntEnum):
+    SSD = 0
+    DRAM = 1
+    PS = 2
+
+    @classmethod
+    # pyre-ignore[3]
+    def from_str(cls, key: str):
+        lookup = {
+            "ssd": BackendType.SSD,
+            "dram": BackendType.DRAM,
+        }
+        if key in lookup:
+            return lookup[key]
+        else:
+            raise ValueError(f"Cannot parse value into BackendType: {key}")
+
+
 class CacheAlgorithm(enum.Enum):
     LRU = 0
     LFU = 1
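
As an illustration (not part of the diff), here is a minimal usage sketch of the two new types; the values are made up, and the import path is assumed from the file's location:

    from fbgemm_gpu.split_table_batched_embeddings_ops_common import (
        BackendType,
        KVZCHParams,
    )

    # Two logical tables, each owning a half-open range [start, end) of
    # global bucket ids; each bucket covers input_space / total_num_buckets ids.
    params = KVZCHParams(
        bucket_offsets=[(0, 4), (4, 8)],
        bucket_sizes=[2**50 // 8, 2**50 // 8],
    )

    assert BackendType.from_str("ssd") == BackendType.SSD
    assert BackendType.from_str("dram") == BackendType.DRAM
    # Note: "ps" is not in the lookup map, so from_str("ps") raises a
    # ValueError even though BackendType.PS exists.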

fbgemm_gpu/fbgemm_gpu/tbe/ssd/utils/partially_materialized_tensor.py

Lines changed: 13 additions & 1 deletion

@@ -33,7 +33,7 @@ class PartiallyMaterializedTensor:
     or use `full_tensor()` to get the full tensor (this could OOM).
     """

-    def __init__(self, wrapped) -> None:
+    def __init__(self, wrapped, is_virtual: bool = False) -> None:
         """
         Ensure caller loads the module before creating this object.

@@ -48,6 +48,7 @@ def __init__(self, wrapped) -> None:
             wrapped: torch.classes.fbgemm.KVTensorWrapper
         """
         self._wrapped = wrapped
+        self._is_virtual = is_virtual
         self._requires_grad = False

     @property
@@ -57,6 +58,17 @@ def wrapped(self):
         """
         return self._wrapped

+    @property
+    def is_virtual(self):
+        """
+        Indicate whether this PMT is a virtual tensor.
+        This indicator is needed by checkpoint and publish: they must know
+        whether a PMT backs a KVZCH table or a normal embedding table, and
+        for KVZCH they need to call all-gather to recalculate the correct
+        metadata of the ShardedTensor.
+        """
+        return self._is_virtual
+
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if kwargs is None:
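
A sketch of how a checkpoint/publish path could consume the new flag (hypothetical caller code; `wrapped` stands for a torch.classes.fbgemm.KVTensorWrapper instance, as the constructor docstring requires):

    pmt = PartiallyMaterializedTensor(wrapped, is_virtual=True)

    if pmt.is_virtual:
        # KVZCH table: all-gather to recalculate the ShardedTensor metadata
        # before checkpoint/publish consumes it.
        rebuild_sharded_tensor_metadata(pmt)  # hypothetical helper
    else:
        # Normal SSD embedding table: existing path, metadata is already correct.
        pass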

fbgemm_gpu/include/fbgemm_gpu/split_embeddings_cache/kv_db_cpp_utils.h

Lines changed: 1 addition & 1 deletion

@@ -46,7 +46,7 @@ inline size_t hash_shard(int64_t id, size_t num_shards) {
 ///
 /// @param unordered_indices unordered ids, the id here might be
 /// original(unlinearized) id
-/// @param hash_mode 0 for hash by mod, 1 for hash by interleave
+/// @param hash_mode 0 for chunk-based hashing, 1 for interleave-based hashing
 /// @param bucket_start global bucket id, the start of the bucket range
 /// @param bucket_end global bucket id, the end of the bucket range
 /// @param bucket_size an optional, virtual size(input space, e.g. 2^50) of a

fbgemm_gpu/src/split_embeddings_cache/kv_db_cpp_utils.cpp

Lines changed: 13 additions & 8 deletions

@@ -23,10 +23,10 @@ int64_t _get_bucket_id(
     std::optional<int64_t> total_num_buckets = std::nullopt) {
   if (hash_mode == 0) {
     CHECK(bucket_size.has_value());
-    // hash by mod
+    // chunk-based hashing
     return id / bucket_size.value();
   } else {
-    // hash by interleave
+    // interleave-based hashing
     CHECK(total_num_buckets.has_value());
     return id % total_num_buckets.value();
   }
@@ -42,7 +42,7 @@ std::tuple<at::Tensor, at::Tensor> get_bucket_sorted_indices_and_bucket_tensor(
   TORCH_CHECK(unordered_indices.is_contiguous());
   TORCH_CHECK(
       hash_mode == 0 || hash_mode == 1,
-      "only support hash by mod and interleaved for now");
+      "only support hash by chunk-based or interleaved-based hashing for now");
   TORCH_CHECK(
       bucket_start <= bucket_end,
       "bucket_start:",
@@ -73,11 +73,16 @@ std::tuple<at::Tensor, at::Tensor> get_bucket_sorted_indices_and_bucket_tensor(
   for (int64_t i = 0; i < num_indices; ++i) {
     auto global_bucket_id = _get_bucket_id(
         indices_data_ptr[i], hash_mode, bucket_size, total_num_buckets);
-    CHECK(global_bucket_id >= bucket_start && global_bucket_id < bucket_end)
-        << "indices: " << indices_data_ptr[i]
-        << " bucket id: " << global_bucket_id
-        << " must fall into the range between:" << bucket_start << " and "
-        << bucket_end;
+    TORCH_CHECK(
+        global_bucket_id >= bucket_start && global_bucket_id < bucket_end,
+        "indices: ",
+        indices_data_ptr[i],
+        " bucket id: ",
+        global_bucket_id,
+        " must fall into the range between:",
+        bucket_start,
+        " and ",
+        bucket_end);
     if (bucket_id_to_cnt.find(global_bucket_id) == bucket_id_to_cnt.end()) {
       bucket_id_to_cnt[global_bucket_id] = 0;
     }
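
For intuition, the two hash modes in _get_bucket_id reduce to the following Python sketch (illustrative only, not the actual binding):

    def get_bucket_id(idx: int, hash_mode: int, bucket_size: int, total_num_buckets: int) -> int:
        if hash_mode == 0:
            # chunk-based: contiguous chunks of ids share a bucket, e.g. with
            # bucket_size=4, ids 0-3 -> bucket 0 and ids 4-7 -> bucket 1
            return idx // bucket_size
        else:
            # interleave-based: consecutive ids round-robin across buckets, e.g.
            # with total_num_buckets=4, ids 0,4,8 -> bucket 0 and ids 1,5,9 -> bucket 1
            return idx % total_num_buckets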

fbgemm_gpu/src/ssd_split_embeddings_cache/embedding_rocksdb_wrapper.h

Lines changed: 7 additions & 2 deletions

@@ -86,9 +86,14 @@ class EmbeddingRocksDBWrapper : public torch::jit::CustomClassHolder {
       int64_t start_id,
       int64_t end_id,
       int64_t id_offset,
-      c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper> snapshot_handle) {
+      std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
+          snapshot_handle) {
     return impl_->get_keys_in_range_by_snapshot(
-        start_id, end_id, id_offset, snapshot_handle->handle);
+        start_id,
+        end_id,
+        id_offset,
+        snapshot_handle.has_value() ? snapshot_handle.value()->handle
+                                    : nullptr);
   }

   void toggle_compaction(bool enable) {
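
Since snapshot_handle is now optional, a Python caller can presumably pass None when no snapshot exists, and the wrapper forwards a null handle (a sketch, assuming the TorchScript binding maps None to std::nullopt):

    # With a snapshot handle (previous behavior):
    ids = ssd_db.get_keys_in_range_by_snapshot(start_id, end_id, id_offset, snapshot)
    # Without one; impl_ receives a nullptr handle:
    ids = ssd_db.get_keys_in_range_by_snapshot(start_id, end_id, id_offset, None)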

fbgemm_gpu/src/ssd_split_embeddings_cache/kv_tensor_wrapper.h

Lines changed: 1 addition & 1 deletion

@@ -44,7 +44,7 @@ class KVTensorWrapper : public torch::jit::CustomClassHolder {
       int64_t dtype,
       int64_t row_offset,
       std::optional<c10::intrusive_ptr<EmbeddingSnapshotHandleWrapper>>
-          snapshot_handle);
+          snapshot_handle = std::nullopt);

   at::Tensor narrow(int64_t dim, int64_t start, int64_t length);

fbgemm_gpu/src/ssd_split_embeddings_cache/ssd_table_batched_embeddings.h

Lines changed: 1 addition & 1 deletion

@@ -547,7 +547,7 @@ class EmbeddingRocksDB : public kv_db::EmbeddingKVDB {
   }

   at::Tensor returned_keys = at::empty(
-      total_num, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
+      {total_num, 1}, at::TensorOptions().device(at::kCPU).dtype(at::kLong));
   auto key_ptr = returned_keys.data_ptr<int64_t>();
   int64_t offset = 0;
   for (const auto& keys : keys_in_db_shards) {
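
Because returned_keys is now allocated as {total_num, 1}, callers that assumed a 1-D key tensor must flatten it, which is exactly what the test change below does with .view(-1). A small sketch (assuming the wrapper call shown earlier):

    keys = ssd_db.get_keys_in_range_by_snapshot(start_id, end_id, 0, snapshot)
    assert keys.dim() == 2 and keys.size(1) == 1  # shape [total_num, 1]
    flat_keys, _ = torch.sort(keys.view(-1))      # flatten before comparing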

fbgemm_gpu/test/tbe/ssd/ssd_l2_cache_test.py

Lines changed: 5 additions & 5 deletions

@@ -269,7 +269,6 @@ def test_rocksdb_get_discrete_ids(
         mixed: bool,
         weights_precision: SparseType,
     ) -> None:
-        weights_precision: SparseType = SparseType.FP32
         emb, Es, Ds, max_D = self.generate_fbgemm_ssd_tbe(
             T, D, log_E, weights_precision, mixed, False, 8
         )
@@ -306,7 +305,7 @@ def test_rocksdb_get_discrete_ids(
            start_id + offset, end_id + offset, offset, snapshot
        )
        ids_in_range_ordered, _ = torch.sort(ids_in_range)
-       id_tensor_ordered, _ = torch.sort(id_tensor)
+       id_tensor_ordered, _ = torch.sort(id_tensor.view(-1))

        assert torch.equal(ids_in_range_ordered, id_tensor_ordered)

@@ -377,7 +376,8 @@ def test_get_bucket_sorted_indices(
        else:
            # test failure
            with self.assertRaisesRegex(
-               RuntimeError, "only support hash by mod and interleaved for now"
+               RuntimeError,
+               "only support hash by chunk-based or interleaved-based hashing for now",
            ):
                torch.ops.fbgemm.get_bucket_sorted_indices_and_bucket_tensor(
                    indices,
@@ -400,7 +400,7 @@ def test_get_bucket_sorted_indices(
            last_bucket_id = cur_bucket_id
            # Calculate expected tensor output
            expected_bucket_tensor = torch.zeros(
-               bucket_end - bucket_start, 1, dtype=torch.int64
+               bucket_end - bucket_start, dtype=torch.int64
            )
            for index in indices:
                self.assertTrue(hash_mode >= 0 and hash_mode <= 1)
@@ -412,4 +412,4 @@ def test_get_bucket_sorted_indices(
                expected_bucket_tensor[bucket_id - bucket_start] += 1

        # Compare actual and expected tensor outputs
-       self.assertTrue(torch.equal(bucket_t, expected_bucket_tensor))
+       self.assertTrue(torch.equal(bucket_t.view(-1), expected_bucket_tensor))
