Commit d2c2745

Anish Khazane authored and facebook-github-bot committed
Add support for pruning non-pooled embedding collection features (#2816)
Summary: Adds ITEPEmbeddingCollectionSharder to properly prune non-pooled embedding tables.

Reviewed By: doIIarplus

Differential Revision: D71022806
Parent: 76446e7

File tree

3 files changed (+343, -8 lines)

torchrec/distributed/embedding.py (+19, -4)
@@ -7,7 +7,6 @@
 
 # pyre-strict
 
-
 import copy
 import logging
 import warnings
@@ -416,6 +415,7 @@ def __init__(
         self.input_features: List[KeyedJaggedTensor] = input_features or []
         self.reverse_indices: List[torch.Tensor] = reverse_indices or []
         self.seq_vbe_ctx: List[SequenceVBEContext] = seq_vbe_ctx or []
+        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = {}
 
     def record_stream(self, stream: torch.Stream) -> None:
         for ctx in self.sharding_contexts:
@@ -548,6 +548,9 @@ def __init__(
             table_name_to_parameter_sharding,
             fused_params,
         )
+
+        self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())
+
         self._sharding_type_to_sharding: Dict[
             str,
             EmbeddingSharding[
@@ -1018,14 +1021,25 @@ def _generate_permute_indices_per_feature(
     def _create_hash_size_info(
         self,
         feature_names: List[str],
+        ctx: Optional[EmbeddingCollectionContext] = None,
     ) -> None:
         feature_index = 0
+        table_to_unpruned_size_mapping: Optional[Dict[str, int]] = None
+        if (
+            ctx is not None
+            and getattr(ctx, "table_name_to_unpruned_hash_sizes", None)
+            and len(ctx.table_name_to_unpruned_hash_sizes) > 0
+        ):
+            table_to_unpruned_size_mapping = ctx.table_name_to_unpruned_hash_sizes
         for i, sharding in enumerate(self._sharding_type_to_sharding.values()):
             feature_hash_size: List[int] = []
             feature_hash_size_lengths: List[int] = []
             for table in sharding.embedding_tables():
                 table_hash_size = [0] * table.num_features()
-                table_hash_size[-1] = table.num_embeddings
+                if table_to_unpruned_size_mapping and table.name:
+                    table_hash_size[-1] = table_to_unpruned_size_mapping[table.name]
+                else:
+                    table_hash_size[-1] = table.num_embeddings
                 feature_hash_size.extend(table_hash_size)
 
                 table_hash_size = [0] * table.num_features()
@@ -1063,6 +1077,7 @@ def _create_hash_size_info(
     def _create_input_dist(
         self,
         input_feature_names: List[str],
+        ctx: Optional[EmbeddingCollectionContext] = None,
    ) -> None:
         feature_names: List[str] = []
         self._feature_splits: List[int] = []
@@ -1085,7 +1100,7 @@ def _create_input_dist(
             )
 
         if self._use_index_dedup:
-            self._create_hash_size_info(feature_names)
+            self._create_hash_size_info(feature_names, ctx)
 
     def _create_lookups(self) -> None:
         for sharding in self._sharding_type_to_sharding.values():
@@ -1225,7 +1240,7 @@ def input_dist(
             need_permute = False
         features = maybe_td_to_kjt(features, feature_keys)  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(input_feature_names=features.keys())
+            self._create_input_dist(input_feature_names=features.keys(), ctx=ctx)
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
             unpadded_features = None
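
The change to _create_hash_size_info above boils down to a fallback: when the context carries table_name_to_unpruned_hash_sizes, the last feature slot of each table uses the unpruned size; otherwise it keeps the table's own num_embeddings. The following standalone sketch (not part of the commit) mirrors that logic; the _Table stand-in, the hash_sizes helper, and the sample sizes are hypothetical, not torchrec classes.

from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class _Table:
    # Hypothetical stand-in for a sharded table's metadata (not a torchrec class).
    name: str
    num_embeddings: int
    num_features: int


def hash_sizes(
    tables: List[_Table],
    unpruned: Optional[Dict[str, int]] = None,
) -> List[int]:
    # Mirrors the diff: only the last feature slot of each table carries a size,
    # taken from the unpruned mapping when one is provided, else num_embeddings.
    out: List[int] = []
    for t in tables:
        sizes = [0] * t.num_features
        if unpruned and t.name in unpruned:
            sizes[-1] = unpruned[t.name]
        else:
            sizes[-1] = t.num_embeddings
        out.extend(sizes)
    return out


print(hash_sizes([_Table("t0", 100, 1), _Table("t1", 50, 2)]))
# [100, 0, 50]
print(hash_sizes([_Table("t0", 100, 1), _Table("t1", 50, 2)], {"t0": 1000}))
# [1000, 0, 50]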

torchrec/distributed/itep_embeddingbag.py (+255, -2)
@@ -16,6 +16,11 @@
 from torch import nn
 from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
+from torchrec.distributed.embedding import (
+    EmbeddingCollectionContext,
+    EmbeddingCollectionSharder,
+    ShardedEmbeddingCollection,
+)
 
 from torchrec.distributed.embedding_types import (
     BaseEmbeddingSharder,
@@ -36,9 +41,12 @@
     ShardingType,
 )
 from torchrec.distributed.utils import filter_state_dict
-from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection
+from torchrec.modules.itep_embedding_modules import (
+    ITEPEmbeddingBagCollection,
+    ITEPEmbeddingCollection,
+)
 from torchrec.modules.itep_modules import GenericITEPModule, RowwiseShardedITEPModule
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
 
 
 @dataclass
@@ -314,3 +322,248 @@ def module_type(self) -> Type[ITEPEmbeddingBagCollection]:
     def sharding_types(self, compute_device_type: str) -> List[str]:
         types = list(SHARDING_TYPE_TO_GROUP.keys())
         return types
+
+
+class ITEPEmbeddingCollectionContext(EmbeddingCollectionContext):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.is_reindexed: bool = False
+        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = {}
+
+
+class ShardedITEPEmbeddingCollection(
+    ShardedEmbeddingModule[
+        KJTList,
+        List[torch.Tensor],
+        Dict[str, JaggedTensor],
+        ITEPEmbeddingCollectionContext,
+    ]
+):
+    def __init__(
+        self,
+        module: ITEPEmbeddingCollection,
+        table_name_to_parameter_sharding: Dict[str, ParameterSharding],
+        ebc_sharder: EmbeddingCollectionSharder,
+        env: ShardingEnv,
+        device: torch.device,
+    ) -> None:
+        super().__init__()
+
+        self._device = device
+        self._env = env
+        self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
+            module._itep_module.table_name_to_unpruned_hash_sizes
+        )
+
+        # Iteration counter for ITEP Module. Pinning on CPU because used for condition checking and checkpointing.
+        self.register_buffer(
+            "_iter", torch.tensor(0, dtype=torch.int64, device=torch.device("cpu"))
+        )
+
+        self._embedding_collection: ShardedEmbeddingCollection = ebc_sharder.shard(
+            module._embedding_collection,
+            table_name_to_parameter_sharding,
+            env=env,
+            device=device,
+        )
+
+        self.table_name_to_sharding_type: Dict[str, str] = {}
+        for table_name in table_name_to_parameter_sharding.keys():
+            self.table_name_to_sharding_type[table_name] = (
+                table_name_to_parameter_sharding[table_name].sharding_type
+            )
+
+        # Group lookups, table_name_to_unpruned_hash_sizes by sharding type and pass to separate itep modules
+        (grouped_lookups, grouped_table_unpruned_size_map) = (
+            self._group_lookups_and_table_unpruned_size_map(
+                module._itep_module.table_name_to_unpruned_hash_sizes,
+            )
+        )
+
+        # Instantiate ITEP Module in sharded case, re-using metadata from non-sharded case
+        self._itep_module: GenericITEPModule = GenericITEPModule(
+            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
+                ShardingTypeGroup.CW_GROUP
+            ],
+            lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP],
+            pruning_interval=module._itep_module.pruning_interval,
+            enable_pruning=module._itep_module.enable_pruning,
+        )
+        self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule(
+            table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
+                ShardingTypeGroup.RW_GROUP
+            ],
+            lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP],
+            pruning_interval=module._itep_module.pruning_interval,
+            table_name_to_sharding_type=self.table_name_to_sharding_type,
+            enable_pruning=module._itep_module.enable_pruning,
+        )
+
+    # pyre-ignore
+    def input_dist(
+        self,
+        ctx: ITEPEmbeddingCollectionContext,
+        features: KeyedJaggedTensor,
+        force_insert: bool = False,
+    ) -> Awaitable[Awaitable[KJTList]]:
+
+        ctx.table_name_to_unpruned_hash_sizes = self.table_name_to_unpruned_hash_sizes
+        return self._embedding_collection.input_dist(ctx, features)
+
+    def compute(
+        self,
+        ctx: ITEPEmbeddingCollectionContext,
+        dist_input: KJTList,
+    ) -> List[torch.Tensor]:
+        for i, (sharding, features) in enumerate(
+            zip(
+                self._embedding_collection._sharding_type_to_sharding.keys(),
+                dist_input,
+            )
+        ):
+            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
+                remapped_kjt = self._itep_module(features, self._iter.item())
+            else:
+                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
+            dist_input[i] = remapped_kjt
+        self._iter += 1
+        return self._embedding_collection.compute(ctx, dist_input)
+
+    def output_dist(
+        self,
+        ctx: ITEPEmbeddingCollectionContext,
+        output: List[torch.Tensor],
+    ) -> LazyAwaitable[Dict[str, JaggedTensor]]:
+
+        ec_awaitable = self._embedding_collection.output_dist(ctx, output)
+        return ec_awaitable
+
+    def compute_and_output_dist(
+        self, ctx: ITEPEmbeddingCollectionContext, input: KJTList
+    ) -> LazyAwaitable[Dict[str, JaggedTensor]]:
+        # Insert forward() function of GenericITEPModule into compute_and_output_dist()
+        """ """
+        for i, (sharding, features) in enumerate(
+            zip(
+                self._embedding_collection._sharding_type_to_sharding.keys(),
+                input,
+            )
+        ):
+            if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
+                remapped_kjt = self._itep_module(features, self._iter.item())
+            else:
+                remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
+            input[i] = remapped_kjt
+        self._iter += 1
+        ec_awaitable = self._embedding_collection.compute_and_output_dist(ctx, input)
+        return ec_awaitable
+
+    def create_context(self) -> ITEPEmbeddingCollectionContext:
+        return ITEPEmbeddingCollectionContext()
+
+    # pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
+    # inconsistently.
+    def load_state_dict(
+        self,
+        state_dict: "OrderedDict[str, torch.Tensor]",
+        strict: bool = True,
+    ) -> _IncompatibleKeys:
+        missing_keys = []
+        unexpected_keys = []
+        self._iter = state_dict["_iter"]
+        for name, child_module in self._modules.items():
+            if child_module is not None:
+                missing, unexpected = child_module.load_state_dict(
+                    filter_state_dict(state_dict, name),
+                    strict,
+                )
+                missing_keys.extend(missing)
+                unexpected_keys.extend(unexpected)
+        return _IncompatibleKeys(
+            missing_keys=missing_keys, unexpected_keys=unexpected_keys
+        )
+
+    def _group_lookups_and_table_unpruned_size_map(
+        self, table_name_to_unpruned_hash_sizes: Dict[str, int]
+    ) -> Tuple[
+        Dict[ShardingTypeGroup, List[nn.Module]],
+        Dict[ShardingTypeGroup, Dict[str, int]],
+    ]:
+        """
+        Group ebc lookups and table_name_to_unpruned_hash_sizes by sharding types.
+        CW and TW are grouped into CW_GROUP, RW and TWRW are grouped into RW_GROUP.
+
+        Return a tuple of (grouped_lookups, grouped_table_unpruned_size_map)
+        """
+        grouped_lookups: Dict[ShardingTypeGroup, List[nn.Module]] = defaultdict(list)
+        grouped_table_unpruned_size_map: Dict[ShardingTypeGroup, Dict[str, int]] = (
+            defaultdict(dict)
+        )
+        for sharding_type, lookup in zip(
+            self._embedding_collection._sharding_types,
+            self._embedding_collection._lookups,
+        ):
+            sharding_group = SHARDING_TYPE_TO_GROUP[sharding_type]
+            # group lookups
+            grouped_lookups[sharding_group].append(lookup)
+            # group table_name_to_unpruned_hash_sizes
+            while isinstance(lookup, DistributedDataParallel):
+                lookup = lookup.module
+            for emb_config in lookup.grouped_configs:
+                for table in emb_config.embedding_tables:
+                    if table.name in table_name_to_unpruned_hash_sizes.keys():
+                        grouped_table_unpruned_size_map[sharding_group][table.name] = (
+                            table_name_to_unpruned_hash_sizes[table.name]
+                        )
+
+        return grouped_lookups, grouped_table_unpruned_size_map
+
+
+class ITEPEmbeddingCollectionSharder(BaseEmbeddingSharder[ITEPEmbeddingCollection]):
+    def __init__(
+        self,
+        ebc_sharder: Optional[EmbeddingCollectionSharder] = None,
+        qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
+    ) -> None:
+        super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
+        self._ebc_sharder: EmbeddingCollectionSharder = (
+            ebc_sharder
+            or EmbeddingCollectionSharder(
+                qcomm_codecs_registry=self.qcomm_codecs_registry
+            )
+        )
+
+    def shard(
+        self,
+        module: ITEPEmbeddingCollection,
+        params: Dict[str, ParameterSharding],
+        env: ShardingEnv,
+        device: Optional[torch.device] = None,
+        module_fqn: Optional[str] = None,
+    ) -> ShardedITEPEmbeddingCollection:
+
+        # Enforce GPU for ITEPEmbeddingBagCollection
+        if device is None:
+            device = torch.device("cuda")
+
+        return ShardedITEPEmbeddingCollection(
+            module,
+            params,
+            ebc_sharder=self._ebc_sharder,
+            env=env,
+            device=device,
+        )
+
+    def shardable_parameters(
+        self, module: ITEPEmbeddingCollection
+    ) -> Dict[str, torch.nn.Parameter]:
+        return self._ebc_sharder.shardable_parameters(module._embedding_collection)
+
+    @property
+    def module_type(self) -> Type[ITEPEmbeddingCollection]:
+        return ITEPEmbeddingCollection
+
+    def sharding_types(self, compute_device_type: str) -> List[str]:
+        types = list(SHARDING_TYPE_TO_GROUP.keys())
+        return types
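
A hedged usage sketch (not from this commit) of wiring the new sharder into DistributedModelParallel. It assumes a distributed process group is already initialized and that itep_ec is an already-constructed ITEPEmbeddingCollection; the shard_itep_ec helper is illustrative, while the torchrec imports and the DistributedModelParallel call are existing APIs.

import torch
from torchrec.distributed.itep_embeddingbag import ITEPEmbeddingCollectionSharder
from torchrec.distributed.model_parallel import DistributedModelParallel
from torchrec.modules.itep_embedding_modules import ITEPEmbeddingCollection


def shard_itep_ec(itep_ec: ITEPEmbeddingCollection) -> DistributedModelParallel:
    # Illustrative helper: the sharder delegates table sharding to
    # EmbeddingCollectionSharder and threads table_name_to_unpruned_hash_sizes
    # through the context during input_dist, as added in this commit.
    return DistributedModelParallel(
        module=itep_ec,
        device=torch.device("cuda"),  # the sharder defaults to CUDA when device is None
        sharders=[ITEPEmbeddingCollectionSharder()],
    )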
