
Commit cd1af80

emlin authored and facebook-github-bot committed
Connect optimizer state_dict with computation kernel for checkpoint saving (pytorch#2975)
Summary: Add a ZeroCollisionKeyValueEmbeddingFusedOptimizer class that integrates the ZCH optimizer state into checkpointing.

Reviewed By: bobbyliujb

Differential Revision: D74790135
1 parent ef5f978 commit cd1af80
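
At a high level, the new optimizer class mirrors the embedding kernel's internally held optimizer state (a per-table "momentum1" tensor) into a torch.optim-style state, so that the optimizer's state_dict, and therefore checkpoint saving, can see it. The snippet below is a minimal, self-contained sketch of that pattern only; the kernel class, table names, and the KernelBackedFusedOptimizer wrapper are illustrative stand-ins and not torchrec APIs.

from typing import Any, Dict, List

import torch


class FakeKVEmbeddingKernel:
    """Stand-in for a fused embedding kernel that owns its optimizer state."""

    def __init__(self, table_names: List[str], rows: int, dim: int) -> None:
        self.weights: Dict[str, torch.Tensor] = {
            name: torch.randn(rows, dim) for name in table_names
        }
        self._momentum: Dict[str, torch.Tensor] = {
            name: torch.zeros(rows) for name in table_names
        }

    def get_optimizer_state(self) -> List[Dict[str, torch.Tensor]]:
        # One dict per table, mirroring the per-table "momentum1" layout in the diff.
        return [{"momentum1": self._momentum[name]} for name in self.weights]

    def get_learning_rate(self) -> float:
        return 0.01


class KernelBackedFusedOptimizer(torch.optim.Optimizer):
    """Optimizer whose state is copied from the kernel rather than produced by step()."""

    def __init__(self, kernel: FakeKVEmbeddingKernel) -> None:
        params: List[torch.Tensor] = list(kernel.weights.values())
        super().__init__(params, {"lr": kernel.get_learning_rate()})

        # Key each kernel-held momentum tensor per table so it shows up in state_dict().
        for (name, weight), opt_state in zip(
            kernel.weights.items(), kernel.get_optimizer_state()
        ):
            self.state[weight] = {f"{name}.momentum1": opt_state["momentum1"]}

    def step(self, closure: Any = None) -> None:
        # The actual parameter update happens inside the fused kernel; nothing to do here.
        pass


if __name__ == "__main__":
    kernel = FakeKVEmbeddingKernel(["table_a", "table_b"], rows=8, dim=4)
    opt = KernelBackedFusedOptimizer(kernel)
    # The checkpoint-facing view now contains one momentum entry per table.
    for param_state in opt.state_dict()["state"].values():
        print(sorted(param_state.keys()))
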

File tree

1 file changed: +113 -4 lines

torchrec/distributed/batched_embedding_kernel.py

Lines changed: 113 additions & 4 deletions
@@ -55,6 +55,7 @@
 from torchrec.distributed.embedding_kernel import (
     BaseEmbedding,
     create_virtual_sharded_tensors,
+    create_virtual_table_local_metadata,
     get_state_dict,
 )
 from torchrec.distributed.embedding_types import (
@@ -206,7 +207,9 @@ def _populate_zero_collision_tbe_params(
     bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]
 
     tbe_params["kv_zch_params"] = KVZCHParams(
-        bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes
+        bucket_offsets=bucket_offsets,
+        bucket_sizes=bucket_sizes,
+        enable_optimizer_offloading=False,
     )
 
 
@@ -283,6 +286,53 @@ def __init__( # noqa C901
             table_name_to_weight_count_per_rank
         )
 
+        # pyre-ignore [33]
+        state: Dict[Any, Any] = {}
+        param_group: Dict[str, Any] = {
+            "params": [],
+            "lr": emb_module.get_learning_rate(),
+        }
+
+        params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
+
+        sorted_id_tensors = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
+
+        all_optimizer_states = emb_module.get_optimizer_state(
+            sorted_id_tensor=sorted_id_tensors
+        )
+        opt_param_list = [param["momentum1"] for param in all_optimizer_states]
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        opt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, opt_param_list, self._pg
+        )
+
+        for (
+            emb_config,
+            sharded_weight,
+            opt_sharded_t,
+        ) in zip(
+            emb_table_config_copy,
+            sharded_embedding_weights_by_table,
+            opt_sharded_t_list,
+        ):
+            param_key = emb_config.name + ".weight"
+            state[sharded_weight] = {}
+            param_group["params"].append(sharded_weight)
+            params[param_key] = sharded_weight
+
+            state[sharded_weight][f"{emb_config.name}.momentum1"] = opt_sharded_t
+
+        super().__init__(params, state, [param_group])
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
@@ -292,6 +342,61 @@ def step(self, closure: Any = None) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
 
+    def set_sharded_embedding_weight_ids(
+        self, sharded_embedding_weight_ids: Optional[List[ShardedTensor]]
+    ) -> None:
+        self._sharded_embedding_weight_ids = sharded_embedding_weight_ids
+
+    def _post_state_dict_hook(self, curr_state: Dict[str, Any]) -> None:
+        logger.info("update optimizer state dict in state_dict_post_hook")
+        embedding_weight_ids = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
+        all_optimizer_states = self._emb_module.get_optimizer_state(
+            embedding_weight_ids,
+            no_snapshot=False,
+            should_flush=False,  # get embedding weights already flushed, no need to flush again here
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+
+        # The order of table_config is determined so put it as outer-loop for consistent traverse order across ranks
+        for table_config, opt_states in zip(
+            emb_table_config_copy,
+            all_optimizer_states,
+        ):
+            for key, sharded_t_dict in curr_state.items():
+                # update zero collision table's optimizer state
+                if f".{table_config.name}.weight" in key:
+                    for (_, opt_state_t), (sharded_t_k, sharded_t) in zip(
+                        opt_states.items(), sharded_t_dict.items()
+                    ):
+                        logger.info(
+                            f"update optimizer state for table {table_config.name} with state shape {opt_state_t.shape}, rank={self._my_rank}, weight_count_per_rank={self._table_name_to_weight_count_per_rank.get(table_config.name, None)}"
+                        )
+                        sharded_t.local_shards()[0].tensor = opt_state_t
+                        create_virtual_table_local_metadata(
+                            # pyre-ignore [6]
+                            table_config.local_metadata,
+                            opt_state_t,
+                            self._my_rank,
+                        )
+                        for shard in sharded_t.local_shards():
+                            shard.metadata = table_config.local_metadata
+                        new_sharded_t = ShardedTensor._init_from_local_shards(
+                            sharded_t.local_shards(),
+                            None,
+                            None,
+                            process_group=self._pg,
+                        )
+                        sharded_t_dict[sharded_t_k] = new_sharded_t
+
 
 class EmbeddingFusedOptimizer(FusedOptimizer):
     def __init__( # noqa C901
@@ -1330,7 +1435,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             return
 
         pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False
+            no_snapshot=False, should_flush=True
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
        for emb_table in emb_table_config_copy:
@@ -1381,12 +1486,16 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_state.fill_(-1)
 
     # pyre-ignore [15]
-    def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
+    def split_embedding_weights(
+        self, no_snapshot: bool = True, should_flush: bool = False
+    ) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
-        return self.emb_module.split_embedding_weights(no_snapshot)
+        return self.emb_module.split_embedding_weights(
+            no_snapshot, should_flush=should_flush
+        )
 
     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         # reset split weights during training
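
The new _post_state_dict_hook follows the usual post-state_dict-hook pattern: once the checkpointable state dict has been assembled, the hook replaces stale optimizer entries with state freshly fetched from the kernel. Below is a hedged, generic sketch of that mechanism using a plain nn.Module buffer; the module, buffer, and hook names are made up for illustration, only nn.Module._register_state_dict_hook is an actual (private) PyTorch hook, and the way torchrec actually wires its hook may differ.

import torch
import torch.nn as nn


class TinyKernel(nn.Module):
    """Toy module standing in for a fused kernel that holds its own optimizer state."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 2))
        # Optimizer state lives inside the kernel, not inside torch.optim.
        self.register_buffer("momentum1", torch.zeros(4))

    def get_optimizer_state(self) -> torch.Tensor:
        # Pretend the fused update advanced the momentum since the buffer was last copied.
        return self.momentum1 + 1.0


def refresh_momentum_hook(module, state_dict, prefix, local_metadata):
    # Analogue of the diff's _post_state_dict_hook: overwrite the stale
    # checkpoint entry in place with the value the kernel currently holds.
    state_dict[prefix + "momentum1"] = module.get_optimizer_state()


if __name__ == "__main__":
    kernel = TinyKernel()
    kernel._register_state_dict_hook(refresh_momentum_hook)
    sd = kernel.state_dict()
    print(sd["momentum1"])  # ones: refreshed by the hook, not the stale zeros buffer
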

0 commit comments