From 1eb623412bd717d5aa3edb42f3ad010fa8ae884f Mon Sep 17 00:00:00 2001
From: Zheng Qi
Date: Wed, 30 Apr 2025 16:01:01 -0700
Subject: [PATCH 1/2] Add enable_raw_embedding_streaming from TBE config to EmbeddingKVDB (#2928)

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/1138

X-link: https://github.com/pytorch/FBGEMM/pull/4053

As titled, plumb this option all the way through to gate the upcoming raw embedding streaming changes in SSD TBE (a usage sketch appears after the second patch below).

Differential Revision: D73691088
---
 torchrec/distributed/types.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index ac7260d25..b021cddc8 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -639,6 +639,7 @@ class KeyValueParams:
         enable_async_update: Optional[bool]: whether to enable async update for l2 cache
         bulk_init_chunk_size: int: number of rows to insert into rocksdb in each chunk
         lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE
+        enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE

         # Parameter Server (PS) Attributes
         ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses
@@ -660,6 +661,9 @@ class KeyValueParams:
     enable_async_update: Optional[bool] = None  # enable L2 cache async update
     bulk_init_chunk_size: Optional[int] = None  # number of rows
     lazy_bulk_init_enabled: Optional[bool] = None  # enable lazy bulk init
+    enable_raw_embedding_streaming: Optional[bool] = (
+        None  # enable raw embedding streaming for SSD TBE
+    )

     # Parameter Server (PS) Attributes
     ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None
@@ -686,6 +690,7 @@ def __hash__(self) -> int:
                 self.enable_async_update,
                 self.bulk_init_chunk_size,
                 self.lazy_bulk_init_enabled,
+                self.enable_raw_embedding_streaming,
             )
         )


From a9fb67f1a4ea5f52851e685176fe743504157f00 Mon Sep 17 00:00:00 2001
From: Zheng Qi
Date: Wed, 30 Apr 2025 16:01:01 -0700
Subject: [PATCH 2/2] Add logic to stream weights in EmbeddingKVDB

Summary:
Gated by `enable_raw_embedding_streaming`.

Add the logic in EmbeddingKVDB to send the passed-in tensors to the `TrainingParameterServerService` thrift service.

The passed-in parameters:
- `table_names`: used to get the table FQN when streaming.
- `table_offsets`: used to get the global row id across TBEs.
- `table_sizes`: the size of each table in the TBE, used to infer which table a specific row belongs to.
- `ps_server_port`: the port on which the local `TrainingParameterServerService` runs and to which tensors are streamed.

It creates a new thread `weights_stream_thread_` in EmbeddingKVDB to stream the weights out of the trainers asynchronously.
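
For illustration only (not part of this patch), a minimal Python sketch of how the four parameters above could be combined to map a TBE-local row back to its table FQN and global row id when streaming; `resolve_streamed_row` is a hypothetical helper, not an EmbeddingKVDB API:

    # Hypothetical helper, for illustration only -- not an EmbeddingKVDB API.
    from typing import List, Tuple

    def resolve_streamed_row(
        local_row: int,
        table_names: List[str],
        table_offsets: List[int],
        table_sizes: List[int],
    ) -> Tuple[str, int]:
        """Map a TBE-local row index to (table FQN, global row id)."""
        cursor = 0
        for name, offset, size in zip(table_names, table_offsets, table_sizes):
            if local_row < cursor + size:
                # table_offsets holds this shard's first global row; add the
                # row's position within its table to get the global row id.
                return name, offset + (local_row - cursor)
            cursor += size
        raise IndexError(f"row {local_row} is outside every table in this TBE")
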
Differential Revision: D73792631
---
 torchrec/distributed/batched_embedding_kernel.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 1aff0ecf6..a245b4dda 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -166,6 +166,17 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
         )
         ssd_tbe_params["cache_sets"] = int(max_cache_sets)

+    ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables]
+    ssd_tbe_params["table_offsets"] = []
+    for emb_tbl in config.embedding_tables:
+        local_metadata = emb_tbl.local_metadata
+        if (
+            local_metadata is not None
+            and local_metadata.shard_offsets is not None
+            and len(local_metadata.shard_offsets) >= 1
+        ):
+            ssd_tbe_params["table_offsets"].append(local_metadata.shard_offsets[0])
+
     return ssd_tbe_params
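
For reference, a minimal usage sketch (not part of either patch) of how a caller might turn on the gate added in the first patch via `KeyValueParams`; only fields visible in that patch's diff are used, and the host/port values are placeholders:

    # Hypothetical usage sketch -- placeholder values, not a working config.
    from torchrec.distributed.types import KeyValueParams

    kv_params = KeyValueParams(
        enable_raw_embedding_streaming=True,  # gate the new SSD TBE streaming path
        ps_hosts=(("127.0.0.1", 12345),),  # PS endpoints (existing field)
    )

How `kv_params` is then attached to the SSD TBE sharding configuration is outside the scope of these diffs and not shown here.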