Skip to content

Commit 3939474

Browse files
Hanna Xu and facebook-github-bot
Hanna Xu
authored and committed
(1/n) Support DI sharding for FPE_EBC (pytorch#2968)
Summary: Support models that have FeatureProcessedEmbeddingBagCollection. These changes ensure we add QuantFeatureProcessedEmbeddingBagCollectionSharder as a recognized sharder, handle the multiple envs needed for specifying DI sharding, and propagate the TBE properly when processing the sharding plan. This does not yet support true hybrid sharding, so all FPE_EBCs must be sharded with the same (sharding_type, device). Differential Revision: D74671655
1 parent b0919ce commit 3939474

File tree

2 files changed

+42
-8
lines changed

2 files changed

+42
-8
lines changed

torchrec/distributed/quant_embeddingbag.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ def __init__(
430430
self,
431431
module: EmbeddingBagCollectionInterface,
432432
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
433-
env: ShardingEnv,
433+
env: Union[ShardingEnv, Dict[str, ShardingEnv]], # support for hybrid sharding
434434
fused_params: Optional[Dict[str, Any]] = None,
435435
device: Optional[torch.device] = None,
436436
feature_processor: Optional[FeatureProcessorsCollection] = None,
@@ -462,11 +462,33 @@ def __init__(
462462
f"Feature processor has inconsistent devices. Expected {feature_processor_device}, got {param.device}"
463463
)
464464

465+
world_sizes = []
466+
if isinstance(env, Dict):
467+
for (
468+
embedding_configs
469+
) in self._sharding_type_device_group_to_sharding_infos.values():
470+
world_sizes.append(
471+
# ensures that the same device is used for this sharding type
472+
env[
473+
get_device_for_first_shard_from_sharding_infos(
474+
embedding_configs
475+
)
476+
].world_size
477+
)
478+
else:
479+
world_sizes.append(env.world_size)
480+
481+
# TODO(hcxu): fully support hybrid sharding with feature_processors_per_rank: ModuleList(ModuleList())
482+
assert (
483+
len(world_sizes) == 1
484+
), "Sharding across multiple (sharding type, device type) for FeatureProcessedEmbeddingBagCollection is not supported yet"
485+
486+
total_world_size = world_sizes[-1]
465487
if feature_processor_device is None:
466-
for _ in range(env.world_size):
488+
for _ in range(total_world_size):
467489
self.feature_processors_per_rank.append(feature_processor)
468490
else:
469-
for i in range(env.world_size):
491+
for i in range(total_world_size):
470492
# Generic copy, for example initailized on cpu but -> sharding as meta
471493
self.feature_processors_per_rank.append(
472494
copy.deepcopy(feature_processor)

torchrec/distributed/quant_state.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,16 @@ def get_bucket_offsets_per_virtual_table(
415415
}
416416

417417

418+
def get_param_id_from_type(is_sqebc: bool, is_sqmcec: bool, is_sfpebc: bool) -> str:
419+
if is_sqebc:
420+
return "embedding_bags"
421+
elif is_sqmcec:
422+
return "_embedding_module.embeddings"
423+
elif is_sfpebc:
424+
return "_embedding_bag_collection.embedding_bags"
425+
return "embeddings"
426+
427+
418428
def sharded_tbes_weights_spec(
419429
sharded_model: torch.nn.Module,
420430
virtual_table_name_to_bucket_lengths: Optional[Dict[str, list[int]]] = None,
@@ -450,11 +460,14 @@ def sharded_tbes_weights_spec(
450460
is_sqebc: bool = "ShardedQuantEmbeddingBagCollection" in type_name
451461
is_sqec: bool = "ShardedQuantEmbeddingCollection" in type_name
452462
is_sqmcec: bool = "ShardedQuantManagedCollisionEmbeddingCollection" in type_name
463+
is_sfpebc: bool = (
464+
"ShardedQuantFeatureProcessedEmbeddingBagCollection" in type_name
465+
)
453466

454-
if is_sqebc or is_sqec or is_sqmcec:
467+
if is_sqebc or is_sqec or is_sqmcec or is_sfpebc:
455468
assert (
456-
is_sqec + is_sqebc + is_sqmcec == 1
457-
), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection and ShardedQuantManagedCollisionEmbeddingCollection are true"
469+
is_sqec + is_sqebc + is_sqmcec + is_sfpebc == 1
470+
), "Cannot have any two of ShardedQuantEmbeddingBagCollection, ShardedQuantEmbeddingCollection, ShardedQuantManagedCollisionEmbeddingCollection and ShardedQuantFeatureProcessedEmbeddingBagCollection are true"
458471
tbes_configs: Dict[
459472
IntNBitTableBatchedEmbeddingBagsCodegen, GroupedEmbeddingConfig
460473
] = module.tbes_configs()
@@ -546,8 +559,7 @@ def sharded_tbes_weights_spec(
546559
row_offsets,
547560
table_metadata.shard_offsets[1],
548561
]
549-
s: str = "embedding_bags" if is_sqebc else "embeddings"
550-
s = ("_embedding_module." if is_sqmcec else "") + s
562+
s: str = get_param_id_from_type(is_sqebc, is_sqmcec, is_sfpebc)
551563
unsharded_fqn_weight_prefix: str = f"{module_fqn}.{s}.{table_name}"
552564
unsharded_fqn_weight: str = unsharded_fqn_weight_prefix + ".weight"
553565

0 commit comments

Comments (0)