
Commit f4b70a0

Hanna Xu authored and facebook-github-bot committed
(1/n) Allow DI sharding for models with FPE_EBC (#2968)
Summary:
Pull Request resolved: #2968

Allow models that contain a FeatureProcessedEmbeddingBagCollection (FPE_EBC) to be DI sharded. However, conservatively enforce that the FPE itself can only be sharded on HBM, not across CPU as well. These changes add QuantFeatureProcessedEmbeddingBagCollectionSharder as a recognized sharder, handle the multiple envs needed to specify DI sharding, and propagate the TBE properly when processing the sharding plan. This does not support true hybrid sharding for FPE.

Differential Revision: D74671655
1 parent d9254e2 commit f4b70a0

File tree

2 files changed: +42 -9 lines

torchrec/distributed/quant_embeddingbag.py

Lines changed: 25 additions & 4 deletions

@@ -430,7 +430,7 @@ def __init__(
         self,
         module: EmbeddingBagCollectionInterface,
         table_name_to_parameter_sharding: Dict[str, ParameterSharding],
-        env: ShardingEnv,
+        env: Union[ShardingEnv, Dict[str, ShardingEnv]],  # support for hybrid sharding
         fused_params: Optional[Dict[str, Any]] = None,
         device: Optional[torch.device] = None,
         feature_processor: Optional[FeatureProcessorsCollection] = None,
@@ -462,11 +462,32 @@ def __init__(
                     f"Feature processor has inconsistent devices. Expected {feature_processor_device}, got {param.device}"
                 )

+        if isinstance(env, Dict):
+            expected_device_type = "cuda"
+            for (
+                embedding_configs
+            ) in self._sharding_type_device_group_to_sharding_infos.values():
+                # throws if any shard does not have the expected device type
+                shard_device_type = get_device_from_sharding_infos(embedding_configs)
+                if isinstance(shard_device_type, tuple):
+                    assert len(set(shard_device_type)) == 1, (
+                        f"Sharding across multiple device types for FeatureProcessedEmbeddingBagCollection "
+                        f"is not supported yet, got {shard_device_type}"
+                    )
+                    shard_device_type = shard_device_type[0]
+                assert (
+                    shard_device_type == expected_device_type
+                ), f"Expected {expected_device_type} but got {shard_device_type} for FeatureProcessedEmbeddingBagCollection sharding device type"
+
+            # TODO(hcxu): support hybrid sharding with feature_processors_per_rank: ModuleList(ModuleList()), if compatible
+            world_size = env[expected_device_type].world_size
+        else:
+            world_size = env.world_size
+
         if feature_processor_device is None:
-            for _ in range(env.world_size):
-                self.feature_processors_per_rank.append(feature_processor)
+            self.feature_processors_per_rank += [feature_processor] * world_size
         else:
-            for i in range(env.world_size):
+            for i in range(world_size):
                 # Generic copy, for example initialized on cpu but -> sharding as meta
                 self.feature_processors_per_rank.append(
                     copy.deepcopy(feature_processor)
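
As a reader aid (not part of the commit), here is a standalone sketch of the control flow the added branch implements: with a per-device-type env dict (DI sharding), every FPE shard must resolve to the "cuda" device type, and the world size is taken from the cuda environment alone. The _Env class and resolve_fpe_world_size helper below are illustrative stand-ins, not TorchRec APIs:

from typing import Dict, Sequence, Union

class _Env:
    # Illustrative stand-in for torchrec's ShardingEnv; only world_size is modeled.
    def __init__(self, world_size: int) -> None:
        self.world_size = world_size

def resolve_fpe_world_size(
    env: Union[_Env, Dict[str, _Env]],
    shard_device_types: Sequence[Union[str, Sequence[str]]],
) -> int:
    # Mirrors the added logic: with a per-device-type dict, FPE shards must all
    # be "cuda" (HBM) and the world size comes from the cuda environment only.
    if isinstance(env, dict):
        expected = "cuda"
        for device_type in shard_device_types:
            if not isinstance(device_type, str):
                assert len(set(device_type)) == 1, (
                    f"FPE sharding across device types is unsupported, got {device_type}"
                )
                device_type = device_type[0]
            assert device_type == expected, f"Expected {expected}, got {device_type}"
        return env[expected].world_size
    return env.world_size

# Single-env sharding keeps the old behavior; a DI env dict reads env["cuda"].
print(resolve_fpe_world_size(_Env(8), ["cuda"]))                                # 8
print(resolve_fpe_world_size({"cuda": _Env(4), "cpu": _Env(16)}, [("cuda",)]))  # 4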

torchrec/distributed/quant_state.py

Lines changed: 17 additions & 5 deletions

@@ -419,6 +419,16 @@ def get_bucket_offsets_per_virtual_table(
     }


+def get_param_id_from_type(is_sqebc: bool, is_sqmcec: bool, is_sfpebc: bool) -> str:
+    if is_sqebc:
+        return "embedding_bags"
+    elif is_sqmcec:
+        return "_embedding_module.embeddings"
+    elif is_sfpebc:
+        return "_embedding_bag_collection.embedding_bags"
+    return "embeddings"
+
+
 def sharded_tbes_weights_spec(
     sharded_model: torch.nn.Module,
     virtual_table_name_to_bucket_lengths: Optional[Dict[str, list[int]]] = None,
@@ -454,11 +464,14 @@ def sharded_tbes_weights_spec(
         is_sqebc: bool = "ShardedQuantEmbeddingBagCollection" in type_name
         is_sqec: bool = "ShardedQuantEmbeddingCollection" in type_name
         is_sqmcec: bool = "ShardedQuantManagedCollisionEmbeddingCollection" in type_name
+        is_sfpebc: bool = (
+            "ShardedQuantFeatureProcessedEmbeddingBagCollection" in type_name
+        )

-        if is_sqebc or is_sqec or is_sqmcec:
+        if is_sqebc or is_sqec or is_sqmcec or is_sfpebc:
             assert (
-                is_sqec + is_sqebc + is_sqmcec == 1
-            ), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection and ShardedQuantManagedCollisionEmbeddingCollection are true"
+                is_sqec + is_sqebc + is_sqmcec + is_sfpebc == 1
+            ), "Only one of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection and ShardedQuantFeatureProcessedEmbeddingBagCollection can be true"
             tbes_configs: Dict[
                 IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig
             ] = module.tbes_configs()
@@ -550,8 +563,7 @@ def sharded_tbes_weights_spec(
                 row_offsets,
                 table_metadata.shard_offsets[1],
             ]
-            s: str = "embedding_bags" if is_sqebc else "embeddings"
-            s = ("_embedding_module." if is_sqmcec else "") + s
+            s: str = get_param_id_from_type(is_sqebc, is_sqmcec, is_sfpebc)
             unsharded_fqn_weight_prefix: str = f"{module_fqn}.{s}.{table_name}"
             unsharded_fqn_weight: str = unsharded_fqn_weight_prefix + ".weight"
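
For reference (not part of the commit), the new helper maps the module-type flags to the parameter path used when composing unsharded weight FQNs. A quick hedged check of that mapping, assuming get_param_id_from_type is importable from torchrec.distributed.quant_state as added above:

from torchrec.distributed.quant_state import get_param_id_from_type

# FPE weights live under the inner embedding bag collection.
assert get_param_id_from_type(is_sqebc=False, is_sqmcec=False, is_sfpebc=True) == "_embedding_bag_collection.embedding_bags"
# Existing module types keep their previous paths.
assert get_param_id_from_type(is_sqebc=True, is_sqmcec=False, is_sfpebc=False) == "embedding_bags"
assert get_param_id_from_type(is_sqebc=False, is_sqmcec=True, is_sfpebc=False) == "_embedding_module.embeddings"
assert get_param_id_from_type(is_sqebc=False, is_sqmcec=False, is_sfpebc=False) == "embeddings"

# e.g. an unsharded FQN for an FPE table then looks like:
#   f"{module_fqn}._embedding_bag_collection.embedding_bags.{table_name}.weight"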
