Support SSD device propagation in TorchRec for RecSys Inference #2961

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account


Open: wants to merge 1 commit into base: main
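
For context, here is a minimal sketch of how the new fused param introduced by this PR might be supplied on the inference side. It is inferred from the diff below; the table names and the choice of a set as the container are illustrative assumptions, not part of this change.

```python
# Illustrative only: "user_id_table" and "item_id_table" are made-up table names.
from torchrec.distributed.fused_params import FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST

fused_params = {
    # Tables listed here are treated as offloaded to SSD storage for inference.
    # Membership is checked with `in`, so any container of table names works.
    FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: {"user_id_table", "item_id_table"},
    # ... any other fused params ...
}
# Passing this dict as a sharder's fused_params lets
# create_sharding_infos_by_sharding_device_group route the listed tables to the
# "ssd" device group (see torchrec/distributed/embedding.py below).
```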
19 changes: 16 additions & 3 deletions torchrec/distributed/embedding.py
@@ -45,6 +45,10 @@
ShardedEmbeddingModule,
ShardingType,
)
from torchrec.distributed.fused_params import (
FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST,
)
from torchrec.distributed.sharding.cw_sequence_sharding import (
CwSequenceEmbeddingSharding,
)
@@ -184,9 +188,16 @@ def create_sharding_infos_by_sharding_device_group(
assert param_name in parameter_by_name or param_name in state_dict
param = parameter_by_name.get(param_name, state_dict[param_name])

device_group: TypeUnion[str, Tuple[str, ...]] = (
get_device_from_parameter_sharding(parameter_sharding)
)
# If a table is overridden to be offloaded to SSD storage for inference,
# update the device group accordingly.
if fused_params and table_name in fused_params.get(
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, {}
):
device_group: TypeUnion[str, Tuple[str, ...]] = "ssd"
else:
device_group: TypeUnion[str, Tuple[str, ...]] = (
get_device_from_parameter_sharding(parameter_sharding)
)
if (
parameter_sharding.sharding_type,
device_group,
@@ -214,6 +225,8 @@ def create_sharding_infos_by_sharding_device_group(
per_table_fused_params, parameter_sharding
)
per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
if device_group == "ssd":
per_table_fused_params.update({FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True})

sharding_type_device_group_to_sharding_infos[
(parameter_sharding.sharding_type, device_group)
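
To make the branching in create_sharding_infos_by_sharding_device_group concrete, the following is a small self-contained sketch, not the TorchRec code itself, of how a table's device group is resolved and how the per-table flag is set. The helper name, table names, and device strings are assumptions for illustration; the key strings come from the diff.

```python
from typing import Any, Dict, Tuple

FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST = "__register_ssd_table_placement_list"
FUSED_PARAM_IS_SSD_TABLE_PLACEMENT = "__register_is_ssd_table_placement"

def resolve_device_group(
    table_name: str,
    sharding_device: str,
    fused_params: Dict[str, Any],
) -> Tuple[str, Dict[str, Any]]:
    # Tables named in the SSD placement list land in the "ssd" group and are
    # tagged per table; everything else keeps the device derived from its
    # parameter sharding.
    per_table_fused_params: Dict[str, Any] = {}
    if table_name in fused_params.get(FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, {}):
        per_table_fused_params[FUSED_PARAM_IS_SSD_TABLE_PLACEMENT] = True
        return "ssd", per_table_fused_params
    return sharding_device, per_table_fused_params

# "table_0" is offloaded to SSD; "table_1" keeps its sharding device ("cpu" here).
fused_params = {FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: {"table_0"}}
assert resolve_device_group("table_0", "cpu", fused_params) == (
    "ssd",
    {FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True},
)
assert resolve_device_group("table_1", "cpu", fused_params) == ("cpu", {})
```

Since (sharding_type, device_group) is the grouping key for sharding infos, SSD-offloaded tables end up in their own sharding group even when they share a sharding type with CPU tables.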
6 changes: 5 additions & 1 deletion torchrec/distributed/embedding_sharding.py
@@ -34,6 +34,7 @@
ListOfKJTList,
ShardedEmbeddingTable,
)
from torchrec.distributed.fused_params import FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST
from torchrec.distributed.types import (
Awaitable,
EmbeddingEvent,
@@ -420,7 +421,7 @@ def _get_grouping_fused_params(
) -> Optional[Dict[str, Any]]:
"""
Only shallow copy the fused params we need for grouping tables into TBEs. In
particular, we do not copy cache_load_factor.
particular, we do not copy cache_load_factor or the SSD embedding table placement list.
"""
grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params)

@@ -430,6 +431,9 @@
if CACHE_LOAD_FACTOR_STR in grouping_fused_params:
del grouping_fused_params[CACHE_LOAD_FACTOR_STR]

if FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST in grouping_fused_params:
del grouping_fused_params[FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST]

if grouping_fused_params.get(USE_ONE_TBE_PER_TABLE, False):
# Replace with unique value to force it into singleton group.
# Name is used as unique value so we won't group multiple shard belonging
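
A brief sketch of the stripping done here: the SSD placement list only controls device-group routing, which is decided before tables are grouped into TBEs, so it is excluded from the shallow copy used to compare tables for grouping, exactly as cache_load_factor is. The helper and example dicts below are illustrative stand-ins, not the real _get_grouping_fused_params, and the cache_load_factor key string is assumed.

```python
import copy
from typing import Any, Dict, Optional

CACHE_LOAD_FACTOR_STR = "cache_load_factor"  # assumed string value
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST = "__register_ssd_table_placement_list"

def grouping_view(fused_params: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    # Shallow copy, then drop the params that should not influence TBE grouping.
    params = copy.copy(fused_params)
    if params:
        params.pop(CACHE_LOAD_FACTOR_STR, None)
        params.pop(FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, None)
    return params

# After stripping, grouping sees only the params that should affect how tables
# are packed into TBEs.
table_a = {CACHE_LOAD_FACTOR_STR: 0.2, FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: {"t0"}}
table_b: Dict[str, Any] = {}
assert grouping_view(table_a) == grouping_view(table_b) == {}
```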
17 changes: 16 additions & 1 deletion torchrec/distributed/embeddingbag.py
@@ -51,6 +51,10 @@
KJTList,
ShardedEmbeddingModule,
)
from torchrec.distributed.fused_params import (
FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST,
)
from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
from torchrec.distributed.sharding.dynamic_sharding import (
@@ -227,7 +231,16 @@ def create_sharding_infos_by_sharding_device_group(
assert param_name in parameter_by_name or param_name in state_dict
param = parameter_by_name.get(param_name, state_dict[param_name])

device_group = get_device_from_parameter_sharding(parameter_sharding)
# If a table is overridden to be offloaded to SSD storage for inference,
# update the device group accordingly.
if fused_params and table_name in fused_params.get(
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, {}
):
device_group: Union[str, Tuple[str, ...]] = "ssd"
else:
device_group: Union[str, Tuple[str, ...]] = (
get_device_from_parameter_sharding(parameter_sharding)
)

if (
parameter_sharding.sharding_type,
@@ -257,6 +270,8 @@ def create_sharding_infos_by_sharding_device_group(
per_table_fused_params, parameter_sharding
)
per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
if device_group == "ssd":
per_table_fused_params.update({FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True})

sharding_type_device_group_to_sharding_infos[
(parameter_sharding.sharding_type, device_group)
7 changes: 7 additions & 0 deletions torchrec/distributed/fused_params.py
@@ -28,6 +28,11 @@
# with certain ways to split models.
FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"

# List of CPU embedding tables offloaded to SSD storage to scale the embedding table size
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: str = "__register_ssd_table_placement_list"
# Per-table bool param indicating whether the table is offloaded to SSD
FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: str = "__register_is_ssd_table_placement"


class TBEToRegisterMixIn:
def get_tbes_to_register(
@@ -111,5 +116,7 @@ def tbe_fused_params(
fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)
if FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST in fused_params_for_tbe:
fused_params_for_tbe.pop(FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST)

return fused_params_for_tbe
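
In the same spirit, tbe_fused_params above keeps the new placement list out of the kwargs handed to the TBE kernel: the TorchRec-internal __register_* control params are popped before the remaining fused params are forwarded. The snippet below is a simplified stand-in for that filtering; "output_dtype" is a made-up example kwarg.

```python
from typing import Any, Dict

FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST = "__register_ssd_table_placement_list"

def strip_placement_list(fused_params: Dict[str, Any]) -> Dict[str, Any]:
    # Copy, then drop the internal placement list so it never reaches the TBE.
    tbe_kwargs = dict(fused_params)
    tbe_kwargs.pop(FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST, None)
    return tbe_kwargs

assert strip_placement_list(
    {"output_dtype": "fp16", FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST: {"t0"}}
) == {"output_dtype": "fp16"}
```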
42 changes: 27 additions & 15 deletions torchrec/distributed/quant_embedding.py
@@ -47,6 +47,7 @@
ShardingType,
)
from torchrec.distributed.fused_params import (
FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
get_tbes_to_register_from_iterable,
@@ -173,12 +174,19 @@ def get_device_from_parameter_sharding(
def get_device_from_sharding_infos(
emb_shard_infos: List[EmbeddingShardingInfo],
) -> Union[str, Tuple[str, ...]]:
res = list(
{
get_device_from_parameter_sharding(ps.param_sharding)
for ps in emb_shard_infos
}
)
res_set = set()
for emb_shard_info in emb_shard_infos:
if (
emb_shard_info.fused_params
and FUSED_PARAM_IS_SSD_TABLE_PLACEMENT in emb_shard_info.fused_params
and emb_shard_info.fused_params[FUSED_PARAM_IS_SSD_TABLE_PLACEMENT]
):
res_set.add("ssd")
else:
res_set.add(
get_device_from_parameter_sharding(emb_shard_info.param_sharding)
)
res = list(res_set)
assert len(res) == 1, "All shards should be on the same type of device"
return res[0]

@@ -201,11 +209,11 @@ def create_infer_embedding_sharding(
List[torch.Tensor],
List[torch.Tensor],
]:
device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
storage_device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
get_device_from_sharding_infos(sharding_infos)
)

if device_type_from_sharding_infos in ["cuda", "mtia"]:
if storage_device_type_from_sharding_infos in ["cuda", "mtia"]:
if sharding_type == ShardingType.TABLE_WISE.value:
return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
elif sharding_type == ShardingType.COLUMN_WISE.value:
@@ -215,31 +223,31 @@
sharding_infos=sharding_infos,
env=env,
device=device,
device_type_from_sharding_infos=device_type_from_sharding_infos,
device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
)
else:
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
)
elif device_type_from_sharding_infos == "cpu" or isinstance(
device_type_from_sharding_infos, tuple
elif storage_device_type_from_sharding_infos in ["cpu", "ssd"] or isinstance(
storage_device_type_from_sharding_infos, tuple
):
if sharding_type == ShardingType.ROW_WISE.value:
return InferRwSequenceEmbeddingSharding(
sharding_infos=sharding_infos,
env=env,
device=device,
device_type_from_sharding_infos=device_type_from_sharding_infos,
device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
)
elif sharding_type == ShardingType.TABLE_WISE.value:
return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
else:
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
)
else:
raise ValueError(
f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
)


@@ -542,6 +550,10 @@ def __init__(
module, table_name_to_parameter_sharding, fused_params
)

self._sharding_type_device_group_to_sharding: Dict[
Tuple[str, Union[str, Tuple[str, ...]]],
EmbeddingSharding[
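
On the inference side, the per-table flag is what get_device_from_sharding_infos consumes. The sketch below is a simplified model of that resolution, not TorchRec code: the Shard tuple and device strings are stand-ins, while the flag name and the homogeneity assertion mirror the diff.

```python
from typing import Dict, List, NamedTuple, Optional

FUSED_PARAM_IS_SSD_TABLE_PLACEMENT = "__register_is_ssd_table_placement"

class Shard(NamedTuple):
    device: str
    fused_params: Optional[Dict[str, bool]] = None

def storage_device_type(shards: List[Shard]) -> str:
    # Each shard contributes "ssd" when tagged, otherwise the device from its
    # parameter sharding; a sharding group must resolve to a single type.
    res = {
        "ssd"
        if (shard.fused_params or {}).get(FUSED_PARAM_IS_SSD_TABLE_PLACEMENT, False)
        else shard.device
        for shard in shards
    }
    assert len(res) == 1, "All shards should be on the same type of device"
    return res.pop()

assert storage_device_type([Shard("cpu", {FUSED_PARAM_IS_SSD_TABLE_PLACEMENT: True})]) == "ssd"
assert storage_device_type([Shard("cpu")]) == "cpu"
```

With this resolution, "ssd" falls into the same branch of create_infer_embedding_sharding as "cpu", so row-wise sharded SSD tables reuse InferRwSequenceEmbeddingSharding.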
1 change: 1 addition & 0 deletions torchrec/distributed/quant_embedding_kernel.py
@@ -34,6 +34,7 @@
from torchrec.distributed.fused_params import (
fused_param_bounds_check_mode,
fused_param_lengths_to_offsets_lookup,
FUSED_PARAM_SSD_TABLE_PLACEMENT_LIST,
is_fused_param_quant_state_dict_split_scale_bias,
is_fused_param_register_tbe,
tbe_fused_params,
22 changes: 14 additions & 8 deletions torchrec/distributed/quant_embeddingbag.py
@@ -35,6 +35,7 @@
create_sharding_infos_by_sharding_device_group,
)
from torchrec.distributed.fused_params import (
FUSED_PARAM_IS_SSD_TABLE_PLACEMENT,
FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
FUSED_PARAM_REGISTER_TBE_BOOL,
get_tbes_to_register_from_iterable,
@@ -97,12 +98,17 @@ def get_device_from_parameter_sharding(
def get_device_from_sharding_infos(
emb_shard_infos: List[EmbeddingShardingInfo],
) -> Union[str, Tuple[str, ...]]:
res = list(
{
get_device_from_parameter_sharding(ps.param_sharding)
for ps in emb_shard_infos
}
)
res_set = set()
for emb_shard_info in emb_shard_infos:
if emb_shard_info.fused_params and emb_shard_info.fused_params.get(
FUSED_PARAM_IS_SSD_TABLE_PLACEMENT, False
):
res_set.add("ssd")
else:
res_set.add(
get_device_from_parameter_sharding(emb_shard_info.param_sharding)
)
res = list(res_set)
assert len(res) == 1, "All shards should be on the same type of device"
return res[0]

@@ -131,7 +137,7 @@ def create_infer_embedding_bag_sharding(
NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor
]:
propogate_device: bool = get_propogate_device()
device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
storage_device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
get_device_from_sharding_infos(sharding_infos)
)
if sharding_type == ShardingType.TABLE_WISE.value:
@@ -143,7 +149,7 @@
sharding_infos,
env,
device=device if propogate_device else None,
device_type_from_sharding_infos=device_type_from_sharding_infos,
device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
)
elif sharding_type == ShardingType.COLUMN_WISE.value:
return InferCwPooledEmbeddingSharding(
5 changes: 4 additions & 1 deletion torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -214,6 +214,9 @@ def forward(
# using _device_type_from_sharding_infos to iterate on local_embs list as
# that's a better practice.
for i, device_type in enumerate(self._device_type_from_sharding_infos):
assert (
device_type != "ssd"
), "Heterogenous sharding across multiple storage device types for a single table not supported for ssd stroage device type"
if device_type != "cpu":
non_cpu_local_embs.append(
_get_batching_hinted_output(
@@ -235,7 +238,7 @@
result.append(non_cpu_local_embs_dist[index])
index += 1
return result
elif self._device_type_from_sharding_infos == "cpu":
elif self._device_type_from_sharding_infos in ["cpu", "ssd"]:
# for cpu and ssd sharders, output dist should be a no-op
return local_embs
else: