Skip to content

Commit d6031f9

Browse files
chouxifacebook-github-bot
authored andcommitted
Add raw embedding streaming needed params in trec and mvai (#2935)
Summary: Pull Request resolved: #2935 Add the variables needed in D73792631 to mvai and torch rec to be able to control them via config. Reviewed By: aliafzal Differential Revision: D74086201 fbshipit-source-id: 53fb269c17f08d87589a837d2049b733db0d665e
1 parent cea9f07 commit d6031f9

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

torchrec/distributed/batched_embedding_kernel.py

+18
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
DenseTableBatchedEmbeddingBagsCodegen,
3838
EmbeddingLocation,
3939
PoolingMode,
40+
RESParams,
4041
SparseType,
4142
SplitTableBatchedEmbeddingBagsCodegen,
4243
)
@@ -171,6 +172,23 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
171172
)
172173
ssd_tbe_params["cache_sets"] = int(max_cache_sets)
173174

175+
# populate res_params, which is used for raw embedding streaming
176+
# here only populates the params available in fused_params and TBE configs
177+
res_params: RESParams = RESParams()
178+
res_params.table_names = [table.name for table in config.embedding_tables]
179+
res_params.table_offsets = []
180+
for emb_tbl in config.embedding_tables:
181+
local_metadata = emb_tbl.local_metadata
182+
if (
183+
local_metadata is not None
184+
and local_metadata.shard_offsets is not None
185+
and len(local_metadata.shard_offsets) >= 1
186+
):
187+
res_params.table_offsets.append(local_metadata.shard_offsets[0])
188+
if "res_store_shards" in fused_params:
189+
res_params.res_store_shards = fused_params.get("res_store_shards")
190+
ssd_tbe_params["res_params"] = res_params
191+
174192
return ssd_tbe_params
175193

176194

torchrec/distributed/types.py

+3
Original file line numberDiff line numberDiff line change
@@ -640,6 +640,7 @@ class KeyValueParams:
640640
bulk_init_chunk_size: int: number of rows to insert into rocksdb in each chunk
641641
lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
642642
enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE
643+
res_store_shards: Optional[int] = None: the number of shards to store the raw embeddings
643644
644645
# Parameter Server (PS) Attributes
645646
ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -664,6 +665,7 @@ class KeyValueParams:
664665
enable_raw_embedding_streaming: Optional[bool] = (
665666
None # enable raw embedding streaming for SSD TBE
666667
)
668+
res_store_shards: Optional[int] = None # shards to store the raw embeddings
667669

668670
# Parameter Server (PS) Attributes
669671
ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -691,6 +693,7 @@ def __hash__(self) -> int:
691693
self.bulk_init_chunk_size,
692694
self.lazy_bulk_init_enabled,
693695
self.enable_raw_embedding_streaming,
696+
self.res_store_shards,
694697
)
695698
)
696699

0 commit comments

Comments
 (0)