add load checkpoint support for virtual table #3037

Status: Closed (wants to merge 1 commit)

12 changes: 9 additions & 3 deletions torchrec/distributed/batched_embedding_kernel.py
@@ -1452,16 +1452,22 @@ def _init_sharded_split_embedding_weights(
pmt_list,
self._pg,
prefix,
self._table_name_to_weight_count_per_rank,
)
weight_id_sharded_t_list = create_virtual_sharded_tensors(
emb_table_config_copy, weight_ids_list, self._pg, prefix # pyre-ignore
emb_table_config_copy,
weight_ids_list, # pyre-ignore [6]
self._pg,
prefix,
self._table_name_to_weight_count_per_rank,
)
bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
emb_table_config_copy,
# pyre-ignore [6]
bucket_cnt_list,
bucket_cnt_list, # pyre-ignore [6]
self._pg,
prefix,
self._table_name_to_weight_count_per_rank,
use_param_size_as_rows=True,
)
# pyre-ignore
assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
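
For orientation, here is a minimal sketch (plain Python, not TorchRec code; the table names, rank count, and row counts are made up) of what the `table_name_to_weight_count_per_rank` mapping threaded into the `create_virtual_sharded_tensors` calls above might look like, and of how `use_param_size_as_rows=True` (passed only for the bucket-count tensors) changes the per-rank row counts.

# Illustrative only: assumed table names, rank count, and row counts.
from typing import Dict, List

table_name_to_weight_count_per_rank: Dict[str, List[int]] = {
    "table_0": [1024, 980, 1101, 0],  # rows held by ranks 0..3
    "table_1": [512, 512, 490, 530],
}


def rows_per_rank(
    table_name: str,
    local_param_rows: int,
    world_size: int,
    use_param_size_as_rows: bool = False,
) -> List[int]:
    # Weights and weight ids: per-rank row counts come from the mapping.
    # Bucket-count tensors (use_param_size_as_rows=True): each rank's shard
    # is sized from the local tensor's own row count instead.
    if use_param_size_as_rows:
        return [local_param_rows] * world_size
    return table_name_to_weight_count_per_rank[table_name]


print(rows_per_rank("table_0", local_param_rows=1024, world_size=4))
# [1024, 980, 1101, 0]
print(rows_per_rank("table_0", local_param_rows=8, world_size=4, use_param_size_as_rows=True))
# [8, 8, 8, 8]
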
18 changes: 15 additions & 3 deletions torchrec/distributed/embedding.py
@@ -30,6 +30,7 @@
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
from torch.distributed._tensor import DTensor
from torch.nn.modules.module import _IncompatibleKeys
from torch.nn.parallel import DistributedDataParallel
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.embedding_lookup import PartiallyMaterializedTensor
@@ -506,6 +507,7 @@ def __init__(
)
self._need_indices: bool = module.need_indices()
self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None
self._skip_missing_weight_key: List[str] = []

for index, (sharding, lookup) in enumerate(
zip(
@@ -705,9 +707,8 @@ def _pre_load_state_dict_hook(

# for loading state_dict into virtual table, we skip the weights assignment
# if needed, for now this should be handled separately outside of load_state_dict call
state_dict[weight_key] = self._model_parallel_name_to_local_shards[
table_name
][0].tensor
self._skip_missing_weight_key.append(weight_key)
del state_dict[weight_key]
continue

key = f"{prefix}embeddings.{table_name}.weight"
@@ -1087,11 +1088,22 @@ def update_destination(
virtual_table_sharded_t_map[table_name][1],
)

def _post_load_state_dict_hook(
module: "ShardedEmbeddingCollection",
incompatible_keys: _IncompatibleKeys,
) -> None:
if incompatible_keys.missing_keys:
# has to remove the key inplace
for skip_key in module._skip_missing_weight_key:
if skip_key in incompatible_keys.missing_keys:
incompatible_keys.missing_keys.remove(skip_key)

self.register_state_dict_pre_hook(self._pre_state_dict_hook)
self._register_state_dict_hook(post_state_dict_hook)
self._register_load_state_dict_pre_hook(
self._pre_load_state_dict_hook, with_module=True
)
self.register_load_state_dict_post_hook(_post_load_state_dict_hook)

self.reset_parameters()

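The embedding.py change above deletes virtual-table weight keys from the incoming state_dict in the pre-hook, records them in `_skip_missing_weight_key`, and then scrubs them from `missing_keys` in a load_state_dict post-hook so strict loading does not report them. Below is a self-contained sketch of that hook pattern using plain PyTorch (not TorchRec code; the module and key names are made up).

import torch
import torch.nn as nn


class ModuleWithSkippedWeight(nn.Module):
    """Toy module that deliberately skips loading one weight (stand-in for a virtual table)."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4))
        # Keys we intentionally do not load; the TorchRec change tracks these
        # in self._skip_missing_weight_key.
        self._skip_missing_weight_key = ["weight"]
        self.register_load_state_dict_post_hook(self._post_load_state_dict_hook)

    @staticmethod
    def _post_load_state_dict_hook(module: nn.Module, incompatible_keys) -> None:
        # Mutate the list in place: load_state_dict checks this same list
        # afterwards when strict=True.
        for skip_key in module._skip_missing_weight_key:
            if skip_key in incompatible_keys.missing_keys:
                incompatible_keys.missing_keys.remove(skip_key)


m = ModuleWithSkippedWeight()
# "weight" is absent from the incoming state_dict (mimicking the pre-hook's
# `del state_dict[weight_key]`), yet strict loading succeeds because the
# post-hook removed it from missing_keys.
result = m.load_state_dict({}, strict=True)
print(result.missing_keys)  # []

Because the `_IncompatibleKeys` object passed to the post-hook holds references to the same lists that `load_state_dict` later inspects, the skipped keys must be removed in place, which is what the "has to remove the key inplace" comment in the diff refers to.
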
83 changes: 54 additions & 29 deletions torchrec/distributed/embedding_kernel.py
@@ -66,62 +66,74 @@ def create_virtual_table_local_metadata(
local_metadata: ShardMetadata,
param: Union[torch.Tensor, PartiallyMaterializedTensor],
my_rank: int,
offset: Optional[int] = None,
weight_count_per_rank: Optional[List[int]] = None,
) -> None:
local_metadata.shard_sizes = list(param.size()) # pyre-ignore
local_metadata.shard_offsets = [0 for _ in range(len(param.size()))] # pyre-ignore
if offset is None:
offset = (
my_rank
if weight_count_per_rank is None
else sum(weight_count_per_rank[:my_rank])
)
local_metadata.shard_sizes = list(param.size()) # pyre-ignore[6]
local_metadata.shard_offsets = [
offset if dim == 0 else 0 for dim in range(len(param.size())) # pyre-ignore[6]
]


def create_virtual_table_global_metadata(
metadata: ShardedTensorMetadata,
my_rank: int,
param: Union[torch.Tensor, PartiallyMaterializedTensor],
weight_count_per_rank: Optional[List[int]],
use_param_size_as_rows: bool,
) -> None:
# update tensor properties from local tensor properties, this should be universal for all ranks
metadata.tensor_properties.dtype = param.dtype
metadata.tensor_properties.requires_grad = param.requires_grad

# manually craft metadata, faking the metadata in a way that all other rank only has 0 row
# NOTE this currently only works for row-wise sharding
fake_total_rows = param.size()[0] # pyre-ignore
metadata.size = torch.Size(
[
fake_total_rows if dim == 0 else param.size(dim)
for dim in range(len(param.size())) # pyre-ignore
]
)
offset = 0

for rank, shard_metadata in enumerate(metadata.shards_metadata):
if use_param_size_as_rows: # respect the param size and treat it as rows
curr_rank_rows = param.size()[0] # pyre-ignore[16]
else:
curr_rank_rows = (
weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
)
if rank < my_rank:
shard_metadata.shard_sizes = [ # pyre-ignore
0 if dim == 0 else param.size(dim)
# pyre-ignore
for dim in range(len(param.size()))
shard_metadata.shard_sizes = [
curr_rank_rows if dim == 0 else param.size(dim)
for dim in range(len(param.size())) # pyre-ignore[6]
]
shard_metadata.shard_offsets = [
0 for _ in range(len(param.size())) # pyre-ignore
offset if dim == 0 else 0 for dim in range(len(param.size())) # pyre-ignore[6]
]
elif rank == my_rank:
create_virtual_table_local_metadata(shard_metadata, param, my_rank)
curr_rank_rows = param.size()[0] # pyre-ignore[16]
create_virtual_table_local_metadata(shard_metadata, param, my_rank, offset)
else:
# pyre-ignore
shard_metadata.shard_sizes = [
0 if dim == 0 else param.size(dim)
# pyre-ignore
for dim in range(len(param.size()))
curr_rank_rows if dim == 0 else param.size(dim)
for dim in range(len(param.size())) # pyre-ignore[6]
]
# pyre-ignore
shard_metadata.shard_offsets = [
param.size(0) if dim == 0 else 0
# pyre-ignore
for dim in range(len(param.size()))
offset if dim == 0 else 0 for dim in range(len(param.size())) # pyre-ignore[6]
]
offset += curr_rank_rows

metadata.size = torch.Size(
[offset if dim == 0 else param.size(dim) for dim in range(len(param.size()))] # pyre-ignore[6]
)


def create_virtual_sharded_tensors(
embedding_tables: List[ShardedEmbeddingTable],
params: Union[List[torch.Tensor], List[PartiallyMaterializedTensor]],
pg: Optional[dist.ProcessGroup] = None,
prefix: str = "",
table_name_to_weight_count_per_rank: Optional[Dict[str, List[int]]] = None,
use_param_size_as_rows: bool = False,
) -> List[ShardedTensor]:
"""
Create virtual sharded tensors for the given embedding tables and parameters.
@@ -139,19 +151,32 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
return prefix + f"{embedding_table.name}"

def get_weight_count_per_rank(table_name: str) -> Optional[List[int]]:
return (
table_name_to_weight_count_per_rank.get(table_name, None)
if table_name_to_weight_count_per_rank
and table_name in table_name_to_weight_count_per_rank.keys()
else None
)

my_rank = dist.get_rank()
for embedding_table, param in zip(embedding_tables, params):
key = get_key_from_embedding_table(embedding_table)
assert embedding_table.use_virtual_table

assert embedding_table.global_metadata is not None
global_metadata = copy.deepcopy(embedding_table.global_metadata)
create_virtual_table_global_metadata(global_metadata, my_rank, param)
weight_count_per_rank = get_weight_count_per_rank(embedding_table.name)
create_virtual_table_global_metadata(
global_metadata,
my_rank,
param,
weight_count_per_rank,
use_param_size_as_rows,
)
key_to_global_metadata[key] = global_metadata

assert embedding_table.local_metadata is not None
local_metadata = copy.deepcopy(embedding_table.local_metadata)
create_virtual_table_local_metadata(local_metadata, param, my_rank)
local_metadata = copy.deepcopy(global_metadata.shards_metadata[my_rank])

key_to_local_shards[key].append(Shard(param, local_metadata)) # pyre-ignore

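
Taken together, the embedding_kernel.py changes replace the old metadata that pretended every other rank holds 0 rows with real row-wise shard metadata: each rank's shard gets its own row count (from `weight_count_per_rank`, or from the local tensor's size when `use_param_size_as_rows` is set), each shard's row offset is the running sum of the preceding ranks' rows, and the global size is the total. The following is a minimal sketch of that bookkeeping with illustrative numbers, not the TorchRec implementation.

from typing import List, Tuple


def rowwise_shard_layout(
    weight_count_per_rank: List[int], num_cols: int
) -> List[Tuple[List[int], List[int]]]:
    """Return (shard_sizes, shard_offsets) per rank for a row-wise sharded virtual table."""
    layout = []
    offset = 0
    for rows in weight_count_per_rank:
        layout.append(([rows, num_cols], [offset, 0]))
        offset += rows  # next rank starts where this one ends
    return layout


# Example: three ranks holding 5, 0 and 7 rows of a table with embedding dim 4.
print(rowwise_shard_layout([5, 0, 7], num_cols=4))
# [([5, 4], [0, 0]), ([0, 4], [5, 0]), ([7, 4], [5, 0])]
# The global size would be [12, 4], matching the torch.Size computed after the loop in the diff.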