diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py index 1aff0ecf6..a245b4dda 100644 --- a/torchrec/distributed/batched_embedding_kernel.py +++ b/torchrec/distributed/batched_embedding_kernel.py @@ -166,6 +166,17 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]: ) ssd_tbe_params["cache_sets"] = int(max_cache_sets) + ssd_tbe_params["table_names"] = [table.name for table in config.embedding_tables] + ssd_tbe_params["table_offsets"] = [] + for emb_tbl in config.embedding_tables: + local_metadata = emb_tbl.local_metadata + if ( + local_metadata is not None + and local_metadata.shard_offsets is not None + and len(local_metadata.shard_offsets) >= 1 + ): + ssd_tbe_params["table_offsets"].append(local_metadata.shard_offsets[0]) + return ssd_tbe_params diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py index ac7260d25..b021cddc8 100644 --- a/torchrec/distributed/types.py +++ b/torchrec/distributed/types.py @@ -639,6 +639,7 @@ class KeyValueParams: enable_async_update: Optional[bool]: whether to enable async update for l2 cache bulk_init_chunk_size: int: number of rows to insert into rocksdb in each chunk lazy_bulk_init_enabled: bool: whether to enable lazy(async) bulk init for SSD TBE + enable_raw_embedding_streaming: Optional[bool]: enable raw embedding streaming for SSD TBE # Parameter Server (PS) Attributes ps_hosts (Optional[Tuple[Tuple[str, int]]]): List of PS host ip addresses @@ -660,6 +661,9 @@ class KeyValueParams: enable_async_update: Optional[bool] = None # enable L2 cache async update bulk_init_chunk_size: Optional[int] = None # number of rows lazy_bulk_init_enabled: Optional[bool] = None # enable lazy bulk init + enable_raw_embedding_streaming: Optional[bool] = ( + None # enable raw embedding streaming for SSD TBE + ) # Parameter Server (PS) Attributes ps_hosts: Optional[Tuple[Tuple[str, int], ...]] = None @@ -686,6 +690,7 @@ def __hash__(self) -> int: self.enable_async_update, self.bulk_init_chunk_size, self.lazy_bulk_init_enabled, + self.enable_raw_embedding_streaming, ) )