 from torchrec.distributed.embedding_kernel import (
     BaseEmbedding,
     create_virtual_sharded_tensors,
+    create_virtual_table_local_metadata,
     get_state_dict,
 )
 from torchrec.distributed.embedding_types import (
@@ -206,7 +207,9 @@ def _populate_zero_collision_tbe_params(
     bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]

     tbe_params["kv_zch_params"] = KVZCHParams(
-        bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes
+        bucket_offsets=bucket_offsets,
+        bucket_sizes=bucket_sizes,
+        enable_optimizer_offloading=False,
     )


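
Illustration (not part of the diff): KVZCHParams now receives each argument on its own line, and the new enable_optimizer_offloading flag is passed explicitly. A minimal sketch of the resulting call, assuming a single local bucket; the offset and size values below are hypothetical, the real ones come from sharded_local_buckets earlier in this function:

    tbe_params["kv_zch_params"] = KVZCHParams(
        bucket_offsets=[(0, 4)],            # hypothetical (start, end) bucket range for this shard
        bucket_sizes=[1024],                # hypothetical ids per bucket
        enable_optimizer_offloading=False,  # new flag added in this diff; left disabled here
    )
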
@@ -283,6 +286,53 @@ def __init__( # noqa C901
             table_name_to_weight_count_per_rank
         )

+        # pyre-ignore [33]
+        state: Dict[Any, Any] = {}
+        param_group: Dict[str, Any] = {
+            "params": [],
+            "lr": emb_module.get_learning_rate(),
+        }
+
+        params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
+
+        sorted_id_tensors = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
+
+        all_optimizer_states = emb_module.get_optimizer_state(
+            sorted_id_tensor=sorted_id_tensors
+        )
+        opt_param_list = [param["momentum1"] for param in all_optimizer_states]
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        opt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, opt_param_list, self._pg
+        )
+
+        for (
+            emb_config,
+            sharded_weight,
+            opt_sharded_t,
+        ) in zip(
+            emb_table_config_copy,
+            sharded_embedding_weights_by_table,
+            opt_sharded_t_list,
+        ):
+            param_key = emb_config.name + ".weight"
+            state[sharded_weight] = {}
+            param_group["params"].append(sharded_weight)
+            params[param_key] = sharded_weight
+
+            state[sharded_weight][f"{emb_config.name}.momentum1"] = opt_sharded_t
+
+        super().__init__(params, state, [param_group])
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
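
Illustration (not part of the diff): the __init__ additions above assemble the three structures the FusedOptimizer base class expects, in the usual torch.optim shape. For a single table named "t1" the result looks roughly like the sketch below, where t1_sharded_weight and t1_momentum_sharded_t are placeholder names for the table's sharded weight and its "momentum1" virtual ShardedTensor:

    params = {"t1.weight": t1_sharded_weight}
    param_group = {"params": [t1_sharded_weight], "lr": emb_module.get_learning_rate()}
    state = {t1_sharded_weight: {"t1.momentum1": t1_momentum_sharded_t}}
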
@@ -292,6 +342,59 @@ def step(self, closure: Any = None) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])

+    def set_sharded_embedding_weight_ids(
+        self, sharded_embedding_weight_ids: Optional[List[ShardedTensor]]
+    ) -> None:
+        self._sharded_embedding_weight_ids = sharded_embedding_weight_ids
+
+    def _post_state_dict_hook(self, curr_state: Dict[str, Any]) -> None:
+        logger.info("update optimizer state dict in state_dict_post_hook")
+        embedding_weight_ids = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
+        all_optimizer_states = self._emb_module.get_optimizer_state(
+            embedding_weight_ids, no_snapshot=False, should_flush=True
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+
+        # The order of table_config is deterministic, so use it as the outer loop for a consistent traversal order across ranks.
+        for table_config, opt_states in zip(
+            emb_table_config_copy,
+            all_optimizer_states,
+        ):
+            for key, sharded_t_dict in curr_state.items():
+                # update the zero-collision table's optimizer state
+                if f".{table_config.name}.weight" in key:
+                    for (_, opt_state_t), (sharded_t_k, sharded_t) in zip(
+                        opt_states.items(), sharded_t_dict.items()
+                    ):
+                        logger.info(
+                            f"update optimizer state for table {table_config.name} with state shape {opt_state_t.shape}, rank={self._my_rank}, weight_count_per_rank={self._table_name_to_weight_count_per_rank.get(table_config.name, None)}"
+                        )
+                        sharded_t.local_shards()[0].tensor = opt_state_t
+                        create_virtual_table_local_metadata(
+                            # pyre-ignore [6]
+                            table_config.local_metadata,
+                            opt_state_t,
+                            self._my_rank,
+                        )
+                        for shard in sharded_t.local_shards():
+                            shard.metadata = table_config.local_metadata
+                        new_sharded_t = ShardedTensor._init_from_local_shards(
+                            sharded_t.local_shards(),
+                            None,
+                            None,
+                            process_group=self._pg,
+                        )
+                        sharded_t_dict[sharded_t_k] = new_sharded_t
+

 class EmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(  # noqa C901
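
Illustration (not part of the diff): the two new methods above are intended to be driven by checkpointing code. A usage sketch, where fused_opt is an instance of this fused optimizer, weight_id_sharded_tensors is the per-table list of weight-id ShardedTensors from the lookup module, and lookup_state is the state dict being post-processed (all three names are placeholders, not APIs introduced by this diff):

    # Hand the sharded weight-id tensors to the optimizer so it can query
    # optimizer state by sorted id when the state dict is materialized.
    fused_opt.set_sharded_embedding_weight_ids(weight_id_sharded_tensors)

    # During state-dict post-processing, rebuild the optimizer's sharded tensors
    # against the freshly flushed backend state.
    fused_opt._post_state_dict_hook(lookup_state)
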
@@ -1330,7 +1433,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             return

         pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False
+            no_snapshot=False, should_flush=True
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
@@ -1381,12 +1484,16 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_state.fill_(-1)

     # pyre-ignore [15]
-    def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
+    def split_embedding_weights(
+        self, no_snapshot: bool = True, should_flush: bool = False
+    ) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
-        return self.emb_module.split_embedding_weights(no_snapshot)
+        return self.emb_module.split_embedding_weights(
+            no_snapshot, should_flush=should_flush
+        )

     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         # reset split weights during training
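
Illustration (not part of the diff): split_embedding_weights keeps should_flush=False as the default, so existing callers are unchanged; checkpoint paths such as get_named_split_embedding_weights_snapshot opt in explicitly. A minimal call sketch, assuming kv_emb is an instance of the class patched above:

    # Take a flushed snapshot so weights, weight ids, and bucket counts are consistent.
    pmt_list, weight_ids_list, bucket_cnt_list = kv_emb.split_embedding_weights(
        no_snapshot=False, should_flush=True
    )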