Add support for pruning non-pooled embedding collection features #2816

Closed · wants to merge 1 commit
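For orientation, a minimal sketch of the new sharder's surface, assuming a torchrec build with this patch applied; it only exercises the metadata accessors shown in the diff below (module_type and sharding_types), nothing distributed:

# Minimal sketch: the new sharder mirrors the existing ITEPEmbeddingBagCollection sharder
# but targets ITEPEmbeddingCollection (non-pooled / sequence embeddings).
# Assumes torchrec is installed with this patch applied.
from torchrec.distributed.itep_embeddingbag import ITEPEmbeddingCollectionSharder

sharder = ITEPEmbeddingCollectionSharder()  # wraps a default EmbeddingCollectionSharder
print(sharder.module_type)             # ITEPEmbeddingCollection
print(sharder.sharding_types("cuda"))  # the keys of SHARDING_TYPE_TO_GROUP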
23 changes: 19 additions & 4 deletions torchrec/distributed/embedding.py
@@ -7,7 +7,6 @@

# pyre-strict


import copy
import logging
import warnings
@@ -416,6 +415,7 @@ def __init__(
self.input_features: List[KeyedJaggedTensor] = input_features or []
self.reverse_indices: List[torch.Tensor] = reverse_indices or []
self.seq_vbe_ctx: List[SequenceVBEContext] = seq_vbe_ctx or []
self.table_name_to_unpruned_hash_sizes: Dict[str, int] = {}

def record_stream(self, stream: torch.Stream) -> None:
for ctx in self.sharding_contexts:
@@ -548,6 +548,9 @@ def __init__(
table_name_to_parameter_sharding,
fused_params,
)

self._sharding_types: List[str] = list(sharding_type_to_sharding_infos.keys())

self._sharding_type_to_sharding: Dict[
str,
EmbeddingSharding[
@@ -1018,14 +1021,25 @@ def _generate_permute_indices_per_feature(
def _create_hash_size_info(
self,
feature_names: List[str],
ctx: Optional[EmbeddingCollectionContext] = None,
) -> None:
feature_index = 0
table_to_unpruned_size_mapping: Optional[Dict[str, int]] = None
if (
ctx is not None
and getattr(ctx, "table_name_to_unpruned_hash_sizes", None)
and len(ctx.table_name_to_unpruned_hash_sizes) > 0
):
table_to_unpruned_size_mapping = ctx.table_name_to_unpruned_hash_sizes
for i, sharding in enumerate(self._sharding_type_to_sharding.values()):
feature_hash_size: List[int] = []
feature_hash_size_lengths: List[int] = []
for table in sharding.embedding_tables():
table_hash_size = [0] * table.num_features()
table_hash_size[-1] = table.num_embeddings
if table_to_unpruned_size_mapping and table.name:
table_hash_size[-1] = table_to_unpruned_size_mapping[table.name]
else:
table_hash_size[-1] = table.num_embeddings
feature_hash_size.extend(table_hash_size)

table_hash_size = [0] * table.num_features()
@@ -1063,6 +1077,7 @@ def _create_hash_size_info(
def _create_input_dist(
self,
input_feature_names: List[str],
ctx: Optional[EmbeddingCollectionContext] = None,
) -> None:
feature_names: List[str] = []
self._feature_splits: List[int] = []
@@ -1085,7 +1100,7 @@
)

if self._use_index_dedup:
self._create_hash_size_info(feature_names)
self._create_hash_size_info(feature_names, ctx)

def _create_lookups(self) -> None:
for sharding in self._sharding_type_to_sharding.values():
@@ -1225,7 +1240,7 @@ def input_dist(
need_permute = False
features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6]
if self._has_uninitialized_input_dist:
self._create_input_dist(input_feature_names=features.keys())
self._create_input_dist(input_feature_names=features.keys(), ctx=ctx)
self._has_uninitialized_input_dist = False
with torch.no_grad():
unpadded_features = None
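The substance of the change above: when the embedding collection context carries a non-empty table_name_to_unpruned_hash_sizes mapping (populated by the ITEP wrapper in the next file), _create_hash_size_info records the unpruned hash size in a table's last feature slot instead of the pruned num_embeddings, and _create_input_dist now threads the context through for that purpose. A standalone sketch of the selection logic, roughly mirroring the diff (the Table tuple is a hypothetical stand-in for the sharded table metadata, not a torchrec type):

# Standalone illustration of the hash-size selection added to _create_hash_size_info.
# `Table` is a hypothetical stand-in for sharded embedding table metadata.
from typing import Dict, List, NamedTuple, Optional


class Table(NamedTuple):
    name: str
    num_features: int
    num_embeddings: int  # pruned row count


def feature_hash_sizes(
    tables: List[Table],
    unpruned_sizes: Optional[Dict[str, int]] = None,
) -> List[int]:
    out: List[int] = []
    for table in tables:
        sizes = [0] * table.num_features
        if unpruned_sizes and table.name in unpruned_sizes:
            sizes[-1] = unpruned_sizes[table.name]  # use the unpruned hash size from the context
        else:
            sizes[-1] = table.num_embeddings  # fall back to the (pruned) table size
        out.extend(sizes)
    return out


tables = [Table(name="t0", num_features=2, num_embeddings=100)]
print(feature_hash_sizes(tables))                # [0, 100]
print(feature_hash_sizes(tables, {"t0": 1000}))  # [0, 1000]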
257 changes: 255 additions & 2 deletions torchrec/distributed/itep_embeddingbag.py
@@ -16,6 +16,11 @@
from torch import nn
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.embedding import (
EmbeddingCollectionContext,
EmbeddingCollectionSharder,
ShardedEmbeddingCollection,
)

from torchrec.distributed.embedding_types import (
BaseEmbeddingSharder,
@@ -36,9 +41,12 @@
ShardingType,
)
from torchrec.distributed.utils import filter_state_dict
from torchrec.modules.itep_embedding_modules import ITEPEmbeddingBagCollection
from torchrec.modules.itep_embedding_modules import (
ITEPEmbeddingBagCollection,
ITEPEmbeddingCollection,
)
from torchrec.modules.itep_modules import GenericITEPModule, RowwiseShardedITEPModule
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor, KeyedTensor
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor


@dataclass
@@ -314,3 +322,248 @@ def module_type(self) -> Type[ITEPEmbeddingBagCollection]:
def sharding_types(self, compute_device_type: str) -> List[str]:
types = list(SHARDING_TYPE_TO_GROUP.keys())
return types


class ITEPEmbeddingCollectionContext(EmbeddingCollectionContext):

def __init__(self) -> None:
super().__init__()
self.is_reindexed: bool = False
self.table_name_to_unpruned_hash_sizes: Dict[str, int] = {}


class ShardedITEPEmbeddingCollection(
ShardedEmbeddingModule[
KJTList,
List[torch.Tensor],
Dict[str, JaggedTensor],
ITEPEmbeddingCollectionContext,
]
):
def __init__(
self,
module: ITEPEmbeddingCollection,
table_name_to_parameter_sharding: Dict[str, ParameterSharding],
ebc_sharder: EmbeddingCollectionSharder,
env: ShardingEnv,
device: torch.device,
) -> None:
super().__init__()

self._device = device
self._env = env
self.table_name_to_unpruned_hash_sizes: Dict[str, int] = (
module._itep_module.table_name_to_unpruned_hash_sizes
)

# Iteration counter for the ITEP module; pinned to CPU because it is used for condition checking and checkpointing.
self.register_buffer(
"_iter", torch.tensor(0, dtype=torch.int64, device=torch.device("cpu"))
)

self._embedding_collection: ShardedEmbeddingCollection = ebc_sharder.shard(
module._embedding_collection,
table_name_to_parameter_sharding,
env=env,
device=device,
)

self.table_name_to_sharding_type: Dict[str, str] = {}
for table_name in table_name_to_parameter_sharding.keys():
self.table_name_to_sharding_type[table_name] = (
table_name_to_parameter_sharding[table_name].sharding_type
)

# Group lookups and table_name_to_unpruned_hash_sizes by sharding type and pass them to separate ITEP modules
(grouped_lookups, grouped_table_unpruned_size_map) = (
self._group_lookups_and_table_unpruned_size_map(
module._itep_module.table_name_to_unpruned_hash_sizes,
)
)

# Instantiate ITEP Module in sharded case, re-using metadata from non-sharded case
self._itep_module: GenericITEPModule = GenericITEPModule(
table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
ShardingTypeGroup.CW_GROUP
],
lookups=grouped_lookups[ShardingTypeGroup.CW_GROUP],
pruning_interval=module._itep_module.pruning_interval,
enable_pruning=module._itep_module.enable_pruning,
)
self._rowwise_itep_module: RowwiseShardedITEPModule = RowwiseShardedITEPModule(
table_name_to_unpruned_hash_sizes=grouped_table_unpruned_size_map[
ShardingTypeGroup.RW_GROUP
],
lookups=grouped_lookups[ShardingTypeGroup.RW_GROUP],
pruning_interval=module._itep_module.pruning_interval,
table_name_to_sharding_type=self.table_name_to_sharding_type,
enable_pruning=module._itep_module.enable_pruning,
)

# pyre-ignore
def input_dist(
self,
ctx: ITEPEmbeddingCollectionContext,
features: KeyedJaggedTensor,
force_insert: bool = False,
) -> Awaitable[Awaitable[KJTList]]:

ctx.table_name_to_unpruned_hash_sizes = self.table_name_to_unpruned_hash_sizes
return self._embedding_collection.input_dist(ctx, features)

def compute(
self,
ctx: ITEPEmbeddingCollectionContext,
dist_input: KJTList,
) -> List[torch.Tensor]:
for i, (sharding, features) in enumerate(
zip(
self._embedding_collection._sharding_type_to_sharding.keys(),
dist_input,
)
):
if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
remapped_kjt = self._itep_module(features, self._iter.item())
else:
remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
dist_input[i] = remapped_kjt
self._iter += 1
return self._embedding_collection.compute(ctx, dist_input)

def output_dist(
self,
ctx: ITEPEmbeddingCollectionContext,
output: List[torch.Tensor],
) -> LazyAwaitable[Dict[str, JaggedTensor]]:

ec_awaitable = self._embedding_collection.output_dist(ctx, output)
return ec_awaitable

def compute_and_output_dist(
self, ctx: ITEPEmbeddingCollectionContext, input: KJTList
) -> LazyAwaitable[Dict[str, JaggedTensor]]:
# Insert forward() function of GenericITEPModule into compute_and_output_dist()
"""Applies the ITEP remapping to the input KJTs, then runs compute_and_output_dist()."""
for i, (sharding, features) in enumerate(
zip(
self._embedding_collection._sharding_type_to_sharding.keys(),
input,
)
):
if SHARDING_TYPE_TO_GROUP[sharding] == ShardingTypeGroup.CW_GROUP:
remapped_kjt = self._itep_module(features, self._iter.item())
else:
remapped_kjt = self._rowwise_itep_module(features, self._iter.item())
input[i] = remapped_kjt
self._iter += 1
ec_awaitable = self._embedding_collection.compute_and_output_dist(ctx, input)
return ec_awaitable

def create_context(self) -> ITEPEmbeddingCollectionContext:
return ITEPEmbeddingCollectionContext()

# pyre-fixme[14]: `load_state_dict` overrides method defined in `Module`
# inconsistently.
def load_state_dict(
self,
state_dict: "OrderedDict[str, torch.Tensor]",
strict: bool = True,
) -> _IncompatibleKeys:
missing_keys = []
unexpected_keys = []
self._iter = state_dict["_iter"]
for name, child_module in self._modules.items():
if child_module is not None:
missing, unexpected = child_module.load_state_dict(
filter_state_dict(state_dict, name),
strict,
)
missing_keys.extend(missing)
unexpected_keys.extend(unexpected)
return _IncompatibleKeys(
missing_keys=missing_keys, unexpected_keys=unexpected_keys
)

def _group_lookups_and_table_unpruned_size_map(
self, table_name_to_unpruned_hash_sizes: Dict[str, int]
) -> Tuple[
Dict[ShardingTypeGroup, List[nn.Module]],
Dict[ShardingTypeGroup, Dict[str, int]],
]:
"""
Group ebc lookups and table_name_to_unpruned_hash_sizes by sharding types.
CW and TW are grouped into CW_GROUP, RW and TWRW are grouped into RW_GROUP.

Return a tuple of (grouped_lookups, grouped _table_unpruned_size_map)
"""
grouped_lookups: Dict[ShardingTypeGroup, List[nn.Module]] = defaultdict(list)
grouped_table_unpruned_size_map: Dict[ShardingTypeGroup, Dict[str, int]] = (
defaultdict(dict)
)
for sharding_type, lookup in zip(
self._embedding_collection._sharding_types,
self._embedding_collection._lookups,
):
sharding_group = SHARDING_TYPE_TO_GROUP[sharding_type]
# group lookups
grouped_lookups[sharding_group].append(lookup)
# group table_name_to_unpruned_hash_sizes
while isinstance(lookup, DistributedDataParallel):
lookup = lookup.module
for emb_config in lookup.grouped_configs:
for table in emb_config.embedding_tables:
if table.name in table_name_to_unpruned_hash_sizes.keys():
grouped_table_unpruned_size_map[sharding_group][table.name] = (
table_name_to_unpruned_hash_sizes[table.name]
)

return grouped_lookups, grouped_table_unpruned_size_map


class ITEPEmbeddingCollectionSharder(BaseEmbeddingSharder[ITEPEmbeddingCollection]):
def __init__(
self,
ebc_sharder: Optional[EmbeddingCollectionSharder] = None,
qcomm_codecs_registry: Optional[Dict[str, QuantizedCommCodecs]] = None,
) -> None:
super().__init__(qcomm_codecs_registry=qcomm_codecs_registry)
self._ebc_sharder: EmbeddingCollectionSharder = (
ebc_sharder
or EmbeddingCollectionSharder(
qcomm_codecs_registry=self.qcomm_codecs_registry
)
)

def shard(
self,
module: ITEPEmbeddingCollection,
params: Dict[str, ParameterSharding],
env: ShardingEnv,
device: Optional[torch.device] = None,
module_fqn: Optional[str] = None,
) -> ShardedITEPEmbeddingCollection:

# Enforce GPU for ITEPEmbeddingCollection
if device is None:
device = torch.device("cuda")

return ShardedITEPEmbeddingCollection(
module,
params,
ebc_sharder=self._ebc_sharder,
env=env,
device=device,
)

def shardable_parameters(
self, module: ITEPEmbeddingCollection
) -> Dict[str, torch.nn.Parameter]:
return self._ebc_sharder.shardable_parameters(module._embedding_collection)

@property
def module_type(self) -> Type[ITEPEmbeddingCollection]:
return ITEPEmbeddingCollection

def sharding_types(self, compute_device_type: str) -> List[str]:
types = list(SHARDING_TYPE_TO_GROUP.keys())
return types
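The _group_lookups_and_table_unpruned_size_map helper above splits lookups and unpruned hash sizes so that CW/TW-sharded tables are handled by GenericITEPModule and RW/TWRW-sharded tables by RowwiseShardedITEPModule. A standalone sketch of the size-map grouping, using simplified stand-ins for torchrec's sharding-type strings and the SHARDING_TYPE_TO_GROUP mapping:

# Standalone sketch: group table_name_to_unpruned_hash_sizes by sharding-type group.
# Group and SHARDING_TYPE_TO_GROUP below are simplified stand-ins, not the torchrec definitions.
from collections import defaultdict
from enum import Enum
from typing import Dict


class Group(Enum):
    CW_GROUP = 1  # column-wise / table-wise shardings
    RW_GROUP = 2  # row-wise / table-row-wise shardings


SHARDING_TYPE_TO_GROUP = {
    "column_wise": Group.CW_GROUP,
    "table_wise": Group.CW_GROUP,
    "row_wise": Group.RW_GROUP,
    "table_row_wise": Group.RW_GROUP,
}


def group_unpruned_sizes(
    table_to_sharding_type: Dict[str, str],
    unpruned_sizes: Dict[str, int],
) -> Dict[Group, Dict[str, int]]:
    grouped: Dict[Group, Dict[str, int]] = defaultdict(dict)
    for table, sharding_type in table_to_sharding_type.items():
        if table in unpruned_sizes:
            grouped[SHARDING_TYPE_TO_GROUP[sharding_type]][table] = unpruned_sizes[table]
    return grouped


grouped = group_unpruned_sizes(
    {"t0": "column_wise", "t1": "row_wise"},
    {"t0": 1000, "t1": 2000},
)
print(grouped[Group.CW_GROUP])  # {'t0': 1000}
print(grouped[Group.RW_GROUP])  # {'t1': 2000}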