Commit ceebcf0

faran928 authored and facebook-github-bot committed
Support ssd device propagation in Torch Rec for RecSys Inference (#2961)
Summary: Pull Request resolved: #2961

For RecSys inference when tables are offloaded onto SSD:

1. Specify and propagate the tables to be offloaded to SSD in TorchRec via FUSED_PARAMS, as discussed with TroyGarden.
2. Continue using torch.device("cpu") as the compute device while using a separate input/output dist for SSD (the in-house SSD TBE kernel, based on EmbeddingDB, is different from the CPU TBE kernel) by creating a new device group for SSD.

device_type_from_sharding_info will be renamed to storage_device_type_from_sharding_info to clarify its role.

Reviewed By: jiayisuse

Differential Revision: D74378974

fbshipit-source-id: ad528cb35230837ccfc9dac23eff8cf4f9adac6f
1 parent d6031f9 commit ceebcf0
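As a usage sketch of the flow described in the summary (hedged: the table names are made up, and passing fused_params through the sharder constructor is assumed from the existing TorchRec inference sharder API rather than confirmed by this diff), a caller lists the tables to offload in FUSED_PARAM_SSD_TABLE_LIST while compute stays on cpu:

# Hypothetical sketch: request SSD offload for two (made-up) tables at
# inference sharding time. The fused_params constructor argument is assumed
# from the existing TorchRec inference sharders; table names are examples only.
import torch
from torchrec.distributed.fused_params import FUSED_PARAM_SSD_TABLE_LIST
from torchrec.distributed.quant_embeddingbag import QuantEmbeddingBagCollectionSharder

fused_params = {
    # Tables listed here are routed to the new "ssd" device group; all other
    # tables keep the device group derived from their parameter sharding.
    FUSED_PARAM_SSD_TABLE_LIST: ["large_table_0", "large_table_1"],
}

sharder = QuantEmbeddingBagCollectionSharder(fused_params=fused_params)
# Compute still happens on cpu; only the storage / TBE kernel grouping differs
# for the SSD-flagged tables.
compute_device = torch.device("cpu")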

File tree

7 files changed: +84 -29 lines

torchrec/distributed/embedding.py (+16 -3)

@@ -45,6 +45,10 @@
     ShardedEmbeddingModule,
     ShardingType,
 )
+from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE,
+    FUSED_PARAM_SSD_TABLE_LIST,
+)
 from torchrec.distributed.sharding.cw_sequence_sharding import (
     CwSequenceEmbeddingSharding,
 )
@@ -184,9 +188,16 @@ def create_sharding_infos_by_sharding_device_group(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])

-        device_group: TypeUnion[str, Tuple[str, ...]] = (
-            get_device_from_parameter_sharding(parameter_sharding)
-        )
+        # if a table name is overridden to be offloaded to ssd storage for inference
+        # update the device group accordingly
+        if fused_params and table_name in fused_params.get(
+            FUSED_PARAM_SSD_TABLE_LIST, {}
+        ):
+            device_group: TypeUnion[str, Tuple[str, ...]] = "ssd"
+        else:
+            device_group: TypeUnion[str, Tuple[str, ...]] = (
+                get_device_from_parameter_sharding(parameter_sharding)
+            )
         if (
             parameter_sharding.sharding_type,
             device_group,
@@ -214,6 +225,8 @@ def create_sharding_infos_by_sharding_device_group(
                 per_table_fused_params, parameter_sharding
             )
             per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
+            if device_group == "ssd":
+                per_table_fused_params.update({FUSED_PARAM_IS_SSD_TABLE: True})

             sharding_type_device_group_to_sharding_infos[
                 (parameter_sharding.sharding_type, device_group)
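As a small illustration of the grouping key used above (table names and sharding type are hypothetical), SSD-flagged tables end up in their own (sharding_type, device_group) bucket even though both buckets compute on cpu:

# Illustration only: mimics the (sharding_type, device_group) bucketing above
# with plain Python; "table_a"/"table_b" and the sharding type are made up.
from collections import defaultdict

sharding_type_device_group_to_sharding_infos = defaultdict(list)
for table_name, device_group in [("table_a", "cpu"), ("table_b", "ssd")]:
    key = ("row_wise", device_group)  # (parameter_sharding.sharding_type, device_group)
    sharding_type_device_group_to_sharding_infos[key].append(table_name)

print(dict(sharding_type_device_group_to_sharding_infos))
# {('row_wise', 'cpu'): ['table_a'], ('row_wise', 'ssd'): ['table_b']}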

torchrec/distributed/embedding_sharding.py (+5 -1)

@@ -34,6 +34,7 @@
     ListOfKJTList,
     ShardedEmbeddingTable,
 )
+from torchrec.distributed.fused_params import FUSED_PARAM_SSD_TABLE_LIST
 from torchrec.distributed.types import (
     Awaitable,
     EmbeddingEvent,
@@ -420,7 +421,7 @@ def _get_grouping_fused_params(
 ) -> Optional[Dict[str, Any]]:
     """
     Only shallow copy the fused params we need for grouping tables into TBEs. In
-    particular, we do not copy cache_load_factor.
+    particular, we do not copy cache_load_factor or ssd embedding table list.
     """
     grouping_fused_params: Optional[Dict[str, Any]] = copy.copy(fused_params)

@@ -430,6 +431,9 @@ def _get_grouping_fused_params(
         if CACHE_LOAD_FACTOR_STR in grouping_fused_params:
             del grouping_fused_params[CACHE_LOAD_FACTOR_STR]

+        if FUSED_PARAM_SSD_TABLE_LIST in grouping_fused_params:
+            del grouping_fused_params[FUSED_PARAM_SSD_TABLE_LIST]
+
         if grouping_fused_params.get(USE_ONE_TBE_PER_TABLE, False):
             # Replace with unique value to force it into singleton group.
             # Name is used as unique value so we won't group multiple shard belonging

torchrec/distributed/embeddingbag.py (+16 -1)

@@ -51,6 +51,10 @@
     KJTList,
     ShardedEmbeddingModule,
 )
+from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE,
+    FUSED_PARAM_SSD_TABLE_LIST,
+)
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
@@ -227,7 +231,16 @@ def create_sharding_infos_by_sharding_device_group(
         assert param_name in parameter_by_name or param_name in state_dict
         param = parameter_by_name.get(param_name, state_dict[param_name])

-        device_group = get_device_from_parameter_sharding(parameter_sharding)
+        # if a table name is overridden to be offloaded to ssd storage for inference
+        # update the device group accordingly
+        if fused_params and table_name in fused_params.get(
+            FUSED_PARAM_SSD_TABLE_LIST, {}
+        ):
+            device_group: Union[str, Tuple[str, ...]] = "ssd"
+        else:
+            device_group: Union[str, Tuple[str, ...]] = (
+                get_device_from_parameter_sharding(parameter_sharding)
+            )

         if (
             parameter_sharding.sharding_type,
@@ -257,6 +270,8 @@ def create_sharding_infos_by_sharding_device_group(
                 per_table_fused_params, parameter_sharding
             )
             per_table_fused_params = convert_to_fbgemm_types(per_table_fused_params)
+            if device_group == "ssd":
+                per_table_fused_params.update({FUSED_PARAM_IS_SSD_TABLE: True})

             sharding_type_device_group_to_sharding_infos[
                 (parameter_sharding.sharding_type, device_group)

torchrec/distributed/fused_params.py (+8)

@@ -28,6 +28,12 @@
 # with certain ways to split models.
 FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP: str = "__register_lengths_to_offsets_lookup"

+# Fused param storing list of cpu embedding tables offloaded to ssd to scale
+# the embedding table size
+FUSED_PARAM_SSD_TABLE_LIST: str = "__register_ssd_table_list"
+# Bool fused param per table to check if the table is offloaded to SSD
+FUSED_PARAM_IS_SSD_TABLE: str = "__register_is_ssd_table"
+

 class TBEToRegisterMixIn:
     def get_tbes_to_register(
@@ -111,5 +117,7 @@ def tbe_fused_params(
         fused_params_for_tbe.pop(FUSED_PARAM_BOUNDS_CHECK_MODE)
     if FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP in fused_params_for_tbe:
         fused_params_for_tbe.pop(FUSED_PARAM_LENGTHS_TO_OFFSETS_LOOKUP)
+    if FUSED_PARAM_SSD_TABLE_LIST in fused_params_for_tbe:
+        fused_params_for_tbe.pop(FUSED_PARAM_SSD_TABLE_LIST)

     return fused_params_for_tbe
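A minimal sketch of how the two new constants divide responsibilities (plain dicts, not the real TorchRec call path; the "output_dtype" entry is a made-up placeholder): the module-level table list drives grouping and is stripped before params reach the TBE kernel, while the per-table bool set during sharding-info creation travels with the table's fused params.

# Sketch only: plain-dict illustration of the intended split between the two
# new fused params; not the actual TorchRec sharding code path.
from torchrec.distributed.fused_params import (
    FUSED_PARAM_IS_SSD_TABLE,
    FUSED_PARAM_SSD_TABLE_LIST,
)

# Module-level fused params supplied by the caller.
fused_params = {
    FUSED_PARAM_SSD_TABLE_LIST: ["table_b"],
    "output_dtype": "fp16",  # made-up placeholder for an unrelated fused param
}

# Per-table fused params for an SSD table, as derived during sharding-info
# creation: the grouping-only table list is dropped, the per-table flag is set.
per_table_fused_params = dict(fused_params)
per_table_fused_params.pop(FUSED_PARAM_SSD_TABLE_LIST)
per_table_fused_params[FUSED_PARAM_IS_SSD_TABLE] = True

print(per_table_fused_params)
# {'output_dtype': 'fp16', '__register_is_ssd_table': True}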

torchrec/distributed/quant_embedding.py (+21 -15)

@@ -47,6 +47,7 @@
     ShardingType,
 )
 from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
     FUSED_PARAM_REGISTER_TBE_BOOL,
     get_tbes_to_register_from_iterable,
@@ -173,12 +174,17 @@ def get_device_from_parameter_sharding(
 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
 ) -> Union[str, Tuple[str, ...]]:
-    res = list(
-        {
-            get_device_from_parameter_sharding(ps.param_sharding)
-            for ps in emb_shard_infos
-        }
-    )
+    res_set = set()
+    for emb_shard_info in emb_shard_infos:
+        if emb_shard_info.fused_params and emb_shard_info.fused_params.get(
+            FUSED_PARAM_IS_SSD_TABLE, False
+        ):
+            res_set.add("ssd")
+        else:
+            res_set.add(
+                get_device_from_parameter_sharding(emb_shard_info.param_sharding)
+            )
+    res = list(res_set)
     assert len(res) == 1, "All shards should be on the same type of device"
     return res[0]

@@ -201,11 +207,11 @@ def create_infer_embedding_sharding(
     List[torch.Tensor],
     List[torch.Tensor],
 ]:
-    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+    storage_device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
         get_device_from_sharding_infos(sharding_infos)
     )

-    if device_type_from_sharding_infos in ["cuda", "mtia"]:
+    if storage_device_type_from_sharding_infos in ["cuda", "mtia"]:
         if sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         elif sharding_type == ShardingType.COLUMN_WISE.value:
@@ -215,31 +221,31 @@ def create_infer_embedding_sharding(
                 sharding_infos=sharding_infos,
                 env=env,
                 device=device,
-                device_type_from_sharding_infos=device_type_from_sharding_infos,
+                device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
             )
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
+                f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
             )
-    elif device_type_from_sharding_infos == "cpu" or isinstance(
-        device_type_from_sharding_infos, tuple
+    elif storage_device_type_from_sharding_infos in ["cpu", "ssd"] or isinstance(
+        storage_device_type_from_sharding_infos, tuple
     ):
         if sharding_type == ShardingType.ROW_WISE.value:
             return InferRwSequenceEmbeddingSharding(
                 sharding_infos=sharding_infos,
                 env=env,
                 device=device,
-                device_type_from_sharding_infos=device_type_from_sharding_infos,
+                device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
             )
         elif sharding_type == ShardingType.TABLE_WISE.value:
             return InferTwSequenceEmbeddingSharding(sharding_infos, env, device)
         else:
             raise ValueError(
-                f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
+                f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
             )
     else:
         raise ValueError(
-            f"Sharding type not supported {sharding_type} for {device_type_from_sharding_infos} sharding"
+            f"Sharding type not supported {sharding_type} for {storage_device_type_from_sharding_infos} sharding"
         )
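For intuition on the rewritten get_device_from_sharding_infos, a tiny sketch using SimpleNamespace stand-ins for EmbeddingShardingInfo (only the two attributes the function reads are provided; real infos carry many more fields): a group whose tables all carry FUSED_PARAM_IS_SSD_TABLE resolves to the "ssd" storage device type, and mixing SSD and non-SSD tables in one group would trip the single-device assert.

# Sketch with SimpleNamespace stand-ins for EmbeddingShardingInfo; only
# .fused_params and .param_sharding are read by the function below.
from types import SimpleNamespace
from torchrec.distributed.fused_params import FUSED_PARAM_IS_SSD_TABLE
from torchrec.distributed.quant_embedding import get_device_from_sharding_infos

ssd_info = SimpleNamespace(
    fused_params={FUSED_PARAM_IS_SSD_TABLE: True},
    param_sharding=None,  # never consulted for SSD-flagged tables
)
# All infos in the group are SSD-flagged, so the group resolves to "ssd".
assert get_device_from_sharding_infos([ssd_info]) == "ssd"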

torchrec/distributed/quant_embeddingbag.py (+14 -8)

@@ -35,6 +35,7 @@
     create_sharding_infos_by_sharding_device_group,
 )
 from torchrec.distributed.fused_params import (
+    FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_QUANT_STATE_DICT_SPLIT_SCALE_BIAS,
     FUSED_PARAM_REGISTER_TBE_BOOL,
     get_tbes_to_register_from_iterable,
@@ -97,12 +98,17 @@ def get_device_from_parameter_sharding(
 def get_device_from_sharding_infos(
     emb_shard_infos: List[EmbeddingShardingInfo],
 ) -> Union[str, Tuple[str, ...]]:
-    res = list(
-        {
-            get_device_from_parameter_sharding(ps.param_sharding)
-            for ps in emb_shard_infos
-        }
-    )
+    res_set = set()
+    for emb_shard_info in emb_shard_infos:
+        if emb_shard_info.fused_params and emb_shard_info.fused_params.get(
+            FUSED_PARAM_IS_SSD_TABLE, False
+        ):
+            res_set.add("ssd")
+        else:
+            res_set.add(
+                get_device_from_parameter_sharding(emb_shard_info.param_sharding)
+            )
+    res = list(res_set)
     assert len(res) == 1, "All shards should be on the same type of device"
     return res[0]

@@ -131,7 +137,7 @@ def create_infer_embedding_bag_sharding(
     NullShardingContext, InputDistOutputs, List[torch.Tensor], torch.Tensor
 ]:
     propogate_device: bool = get_propogate_device()
-    device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
+    storage_device_type_from_sharding_infos: Union[str, Tuple[str, ...]] = (
         get_device_from_sharding_infos(sharding_infos)
     )
     if sharding_type == ShardingType.TABLE_WISE.value:
@@ -143,7 +149,7 @@ def create_infer_embedding_bag_sharding(
             sharding_infos,
             env,
             device=device if propogate_device else None,
-            device_type_from_sharding_infos=device_type_from_sharding_infos,
+            device_type_from_sharding_infos=storage_device_type_from_sharding_infos,
         )
     elif sharding_type == ShardingType.COLUMN_WISE.value:
         return InferCwPooledEmbeddingSharding(

torchrec/distributed/sharding/rw_sequence_sharding.py (+4 -1)

@@ -214,6 +214,9 @@ def forward(
             # using _device_type_from_sharding_infos to iterate on local_embs list as
             # that's a better practice.
             for i, device_type in enumerate(self._device_type_from_sharding_infos):
+                assert (
+                    device_type != "ssd"
+                ), "Heterogeneous sharding across multiple storage device types for a single table is not supported for the ssd storage device type"
                 if device_type != "cpu":
                     non_cpu_local_embs.append(
                         _get_batching_hinted_output(
@@ -235,7 +238,7 @@ def forward(
                     result.append(non_cpu_local_embs_dist[index])
                     index += 1
             return result
-        elif self._device_type_from_sharding_infos == "cpu":
+        elif self._device_type_from_sharding_infos in ["cpu", "ssd"]:
             # for cpu sharder, output dist should be a no-op
             return local_embs
         else:
