Commit 38a9474

Jianbo Liu authored and facebook-github-bot committed
add load checkpoint support for virtual table
Summary:
X-link: pytorch/FBGEMM#4250

After all of the rebasing and landing, trunk was still missing some of the changes needed for checkpoint loading:
* change `create_virtual_table_global_metadata` to respect the local weight count on each rank, or alternatively to just use the param size as the number of rows on each rank
* register a `load_state_dict` post hook (via `register_load_state_dict_post_hook`) in `ShardedEmbeddingCollection` so it can ignore loading the weight tensor for virtual tables

Differential Revision: D75843542
Privacy Context Container: L1138451
1 parent a5f8103 commit 38a9474

File tree

3 files changed: +78 −35 lines changed
torchrec/distributed/batched_embedding_kernel.py
torchrec/distributed/embedding.py
torchrec/distributed/embedding_kernel.py


torchrec/distributed/batched_embedding_kernel.py

Lines changed: 9 additions & 3 deletions
@@ -1452,16 +1452,22 @@ def _init_sharded_split_embedding_weights(
             pmt_list,
             self._pg,
             prefix,
+            self._table_name_to_weight_count_per_rank,
         )
         weight_id_sharded_t_list = create_virtual_sharded_tensors(
-            emb_table_config_copy, weight_ids_list, self._pg, prefix  # pyre-ignore
+            emb_table_config_copy,
+            weight_ids_list,  # pyre-ignore [6]
+            self._pg,
+            prefix,
+            self._table_name_to_weight_count_per_rank,
         )
         bucket_cnt_sharded_t_list = create_virtual_sharded_tensors(
             emb_table_config_copy,
-            # pyre-ignore [6]
-            bucket_cnt_list,
+            bucket_cnt_list,  # pyre-ignore [6]
             self._pg,
             prefix,
+            self._table_name_to_weight_count_per_rank,
+            use_param_size_as_rows=True,
         )
         # pyre-ignore
         assert len(pmt_list) == len(weight_ids_list) == len(bucket_cnt_list)
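Note that `use_param_size_as_rows=True` is only passed for the bucket-count tensors, so their per-rank row count comes from the local param size rather than from `_table_name_to_weight_count_per_rank`. A minimal standalone sketch of that row-count choice (the `rows_for_rank` helper and the numbers are illustrative only, not part of this change):

from typing import List, Optional

import torch


def rows_for_rank(
    param: torch.Tensor,
    rank: int,
    weight_count_per_rank: Optional[List[int]],
    use_param_size_as_rows: bool,
) -> int:
    # Mirrors the per-rank row selection in create_virtual_table_global_metadata:
    # either treat the local param size as the row count, or trust the
    # checkpointed weight count recorded for that rank.
    if use_param_size_as_rows:
        return param.size(0)
    return weight_count_per_rank[rank] if weight_count_per_rank is not None else 1


bucket_cnt = torch.zeros(4)
assert rows_for_rank(bucket_cnt, 1, [3, 5, 2], use_param_size_as_rows=True) == 4
assert rows_for_rank(bucket_cnt, 1, [3, 5, 2], use_param_size_as_rows=False) == 5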

torchrec/distributed/embedding.py

Lines changed: 15 additions & 3 deletions
@@ -30,6 +30,7 @@
 from torch.autograd.profiler import record_function
 from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
 from torch.distributed._tensor import DTensor
+from torch.nn.modules.module import _IncompatibleKeys
 from torch.nn.parallel import DistributedDataParallel
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.embedding_lookup import PartiallyMaterializedTensor
@@ -506,6 +507,7 @@ def __init__(
         )
         self._need_indices: bool = module.need_indices()
         self._inverse_indices_permute_per_sharding: Optional[List[torch.Tensor]] = None
+        self._skip_missing_weight_key: List[str] = []
 
         for index, (sharding, lookup) in enumerate(
             zip(
@@ -705,9 +707,8 @@ def _pre_load_state_dict_hook(
 
                 # for loading state_dict into virtual table, we skip the weights assignment
                 # if needed, for now this should be handled separately outside of load_state_dict call
-                state_dict[weight_key] = self._model_parallel_name_to_local_shards[
-                    table_name
-                ][0].tensor
+                self._skip_missing_weight_key.append(weight_key)
+                del state_dict[weight_key]
                 continue
 
             key = f"{prefix}embeddings.{table_name}.weight"
@@ -1087,11 +1088,22 @@ def update_destination(
                     virtual_table_sharded_t_map[table_name][1],
                 )
 
+        def _post_load_state_dict_hook(
+            module: "ShardedEmbeddingCollection",
+            incompatible_keys: _IncompatibleKeys,
+        ) -> None:
+            if incompatible_keys.missing_keys:
+                # has to remove the key inplace
+                for skip_key in module._skip_missing_weight_key:
+                    if skip_key in incompatible_keys.missing_keys:
+                        incompatible_keys.missing_keys.remove(skip_key)
+
         self.register_state_dict_pre_hook(self._pre_state_dict_hook)
         self._register_state_dict_hook(post_state_dict_hook)
         self._register_load_state_dict_pre_hook(
             self._pre_load_state_dict_hook, with_module=True
         )
+        self.register_load_state_dict_post_hook(_post_load_state_dict_hook)
 
         self.reset_parameters()
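For reference, `register_load_state_dict_post_hook` is the standard `nn.Module` hook used above: the hook receives the `_IncompatibleKeys` result and must edit `missing_keys` in place for the caller of `load_state_dict` to see the filtered list. A self-contained sketch of the same pattern with a toy module (not TorchRec code):

import torch
import torch.nn as nn


class ToyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.emb = nn.Embedding(4, 2)
        # pretend this weight is intentionally absent from the checkpoint
        self._skip_missing_weight_key = ["emb.weight"]


def _post_load_state_dict_hook(module, incompatible_keys) -> None:
    # missing_keys has to be modified in place so the filtered result
    # is what load_state_dict returns to the caller
    for skip_key in module._skip_missing_weight_key:
        if skip_key in incompatible_keys.missing_keys:
            incompatible_keys.missing_keys.remove(skip_key)


m = ToyModule()
m.register_load_state_dict_post_hook(_post_load_state_dict_hook)
result = m.load_state_dict({}, strict=False)
print(result.missing_keys)  # "emb.weight" has been filtered out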

torchrec/distributed/embedding_kernel.py

Lines changed: 54 additions & 29 deletions
@@ -66,62 +66,74 @@ def create_virtual_table_local_metadata(
     local_metadata: ShardMetadata,
     param: Union[torch.Tensor, PartiallyMaterializedTensor],
     my_rank: int,
+    offset: Optional[int] = None,
+    weight_count_per_rank: Optional[List[int]] = None,
 ) -> None:
-    local_metadata.shard_sizes = list(param.size())  # pyre-ignore
-    local_metadata.shard_offsets = [0 for _ in range(len(param.size()))]  # pyre-ignore
+    if offset is None:
+        offset = (
+            my_rank
+            if weight_count_per_rank is None
+            else sum(weight_count_per_rank[:my_rank])
+        )
+    local_metadata.shard_sizes = list(param.size())  # pyre-ignore[6]
+    local_metadata.shard_offsets = [
+        offset if dim == 0 else 0 for dim in range(len(param.size()))  # pyre-ignore[6]
+    ]
 
 
 def create_virtual_table_global_metadata(
     metadata: ShardedTensorMetadata,
     my_rank: int,
     param: Union[torch.Tensor, PartiallyMaterializedTensor],
+    weight_count_per_rank: Optional[List[int]],
+    use_param_size_as_rows: bool,
 ) -> None:
     # update tensor properties from local tensor properties, this should be universal for all ranks
     metadata.tensor_properties.dtype = param.dtype
     metadata.tensor_properties.requires_grad = param.requires_grad
 
-    # manually craft metadata, faking the metadata in a way that all other rank only has 0 row
-    # NOTE this currently only works for row-wise sharding
-    fake_total_rows = param.size()[0]  # pyre-ignore
-    metadata.size = torch.Size(
-        [
-            fake_total_rows if dim == 0 else param.size(dim)
-            for dim in range(len(param.size()))  # pyre-ignore
-        ]
-    )
+    offset = 0
 
     for rank, shard_metadata in enumerate(metadata.shards_metadata):
+        if use_param_size_as_rows:  # respect the param size and treat it as rows
+            curr_rank_rows = param.size()[0]  # pyre-ignore[16]
+        else:
+            curr_rank_rows = (
+                weight_count_per_rank[rank] if weight_count_per_rank is not None else 1
+            )
         if rank < my_rank:
-            shard_metadata.shard_sizes = [  # pyre-ignore
-                0 if dim == 0 else param.size(dim)
-                # pyre-ignore
-                for dim in range(len(param.size()))
+            shard_metadata.shard_sizes = [
+                curr_rank_rows if dim == 0 else param.size(dim)
+                for dim in range(len(param.size()))  # pyre-ignore[6]
             ]
             shard_metadata.shard_offsets = [
-                0 for _ in range(len(param.size()))  # pyre-ignore
+                offset if dim == 0 else 0 for dim in range(len(param.size()))  # pyre-ignore[6]
            ]
         elif rank == my_rank:
-            create_virtual_table_local_metadata(shard_metadata, param, my_rank)
+            curr_rank_rows = param.size()[0]  # pyre-ignore[16]
+            create_virtual_table_local_metadata(shard_metadata, param, my_rank, offset)
         else:
-            # pyre-ignore
             shard_metadata.shard_sizes = [
-                0 if dim == 0 else param.size(dim)
-                # pyre-ignore
-                for dim in range(len(param.size()))
+                curr_rank_rows if dim == 0 else param.size(dim)
+                for dim in range(len(param.size()))  # pyre-ignore[6]
             ]
-            # pyre-ignore
             shard_metadata.shard_offsets = [
-                param.size(0) if dim == 0 else 0
-                # pyre-ignore
-                for dim in range(len(param.size()))
+                offset if dim == 0 else 0 for dim in range(len(param.size()))  # pyre-ignore[6]
             ]
+        offset += curr_rank_rows
+
+    metadata.size = torch.Size(
+        [offset if dim == 0 else param.size(dim) for dim in range(len(param.size()))]  # pyre-ignore[6]
+    )
 
 
 def create_virtual_sharded_tensors(
     embedding_tables: List[ShardedEmbeddingTable],
     params: Union[List[torch.Tensor], List[PartiallyMaterializedTensor]],
     pg: Optional[dist.ProcessGroup] = None,
     prefix: str = "",
+    table_name_to_weight_count_per_rank: Optional[Dict[str, List[int]]] = None,
+    use_param_size_as_rows: bool = False,
 ) -> List[ShardedTensor]:
     """
     Create virtual sharded tensors for the given embedding tables and parameters.
@@ -139,19 +151,32 @@ def create_virtual_sharded_tensors(
     def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
         return prefix + f"{embedding_table.name}"
 
+    def get_weight_count_per_rank(table_name: str) -> Optional[List[int]]:
+        return (
+            table_name_to_weight_count_per_rank.get(table_name, None)
+            if table_name_to_weight_count_per_rank
+            and table_name in table_name_to_weight_count_per_rank.keys()
+            else None
+        )
+
     my_rank = dist.get_rank()
     for embedding_table, param in zip(embedding_tables, params):
         key = get_key_from_embedding_table(embedding_table)
         assert embedding_table.use_virtual_table
 
         assert embedding_table.global_metadata is not None
         global_metadata = copy.deepcopy(embedding_table.global_metadata)
-        create_virtual_table_global_metadata(global_metadata, my_rank, param)
+        weight_count_per_rank = get_weight_count_per_rank(embedding_table.name)
+        create_virtual_table_global_metadata(
+            global_metadata,
+            my_rank,
+            param,
+            weight_count_per_rank,
+            use_param_size_as_rows,
+        )
         key_to_global_metadata[key] = global_metadata
 
-        assert embedding_table.local_metadata is not None
-        local_metadata = copy.deepcopy(embedding_table.local_metadata)
-        create_virtual_table_local_metadata(local_metadata, param, my_rank)
+        local_metadata = copy.deepcopy(global_metadata.shards_metadata[my_rank])
 
         key_to_local_shards[key].append(Shard(param, local_metadata))  # pyre-ignore
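With this rewrite, shard offsets become running prefix sums of the per-rank row counts and the global size is their total, replacing the old metadata that faked zero rows on every other rank. A standalone sketch of that arithmetic (the `rows_and_offsets` helper is hypothetical, for illustration only):

from typing import List, Tuple


def rows_and_offsets(weight_count_per_rank: List[int]) -> Tuple[int, List[int]]:
    # Row-0 offsets are prefix sums of per-rank row counts; the running
    # total at the end is the global row count recorded in metadata.size.
    offsets: List[int] = []
    offset = 0
    for rows in weight_count_per_rank:
        offsets.append(offset)
        offset += rows
    return offset, offsets


total_rows, offsets = rows_and_offsets([3, 5, 2])
assert total_rows == 10
assert offsets == [0, 3, 8]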
