 from torchrec.distributed.embedding_kernel import (
     BaseEmbedding,
     create_virtual_sharded_tensors,
+    create_virtual_table_local_metadata,
     get_state_dict,
 )
 from torchrec.distributed.embedding_types import (
@@ -206,7 +207,9 @@ def _populate_zero_collision_tbe_params(
     bucket_sizes: List[int] = [size for _, _, size in sharded_local_buckets]

     tbe_params["kv_zch_params"] = KVZCHParams(
-        bucket_offsets=bucket_offsets, bucket_sizes=bucket_sizes
+        bucket_offsets=bucket_offsets,
+        bucket_sizes=bucket_sizes,
+        enable_optimizer_offloading=False,
     )

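For orientation, a minimal sketch of how the bucket metadata passed into KVZCHParams above could be derived; the (start, end, size) tuple layout of sharded_local_buckets and the helper name _split_buckets are assumptions for illustration, not code from this PR:

    # Hypothetical helper showing the bucket_offsets / bucket_sizes split;
    # only the `bucket_sizes` comprehension mirrors the diff above.
    from typing import List, Tuple

    def _split_buckets(
        sharded_local_buckets: List[Tuple[int, int, int]],  # assumed (start, end, size)
    ) -> Tuple[List[Tuple[int, int]], List[int]]:
        bucket_offsets = [(start, end) for start, end, _ in sharded_local_buckets]
        bucket_sizes = [size for _, _, size in sharded_local_buckets]
        return bucket_offsets, bucket_sizes

    # e.g. two local buckets covering id ranges [0, 1000) and [1000, 2000)
    offsets, sizes = _split_buckets([(0, 1000, 1000), (1000, 2000, 1000)])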
@@ -283,6 +286,53 @@ def __init__(  # noqa C901
             table_name_to_weight_count_per_rank
         )

+        # pyre-ignore [33]
+        state: Dict[Any, Any] = {}
+        param_group: Dict[str, Any] = {
+            "params": [],
+            "lr": emb_module.get_learning_rate(),
+        }
+
+        params: Dict[str, Union[torch.Tensor, ShardedTensor]] = {}
+
+        sorted_id_tensors = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
+
+        all_optimizer_states = emb_module.get_optimizer_state(
+            sorted_id_tensor=sorted_id_tensors
+        )
+        opt_param_list = [param["momentum1"] for param in all_optimizer_states]
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+        opt_sharded_t_list = create_virtual_sharded_tensors(
+            emb_table_config_copy, opt_param_list, self._pg
+        )
+
+        for (
+            emb_config,
+            sharded_weight,
+            opt_sharded_t,
+        ) in zip(
+            emb_table_config_copy,
+            sharded_embedding_weights_by_table,
+            opt_sharded_t_list,
+        ):
+            param_key = emb_config.name + ".weight"
+            state[sharded_weight] = {}
+            param_group["params"].append(sharded_weight)
+            params[param_key] = sharded_weight
+
+            state[sharded_weight][f"{emb_config.name}.momentum1"] = opt_sharded_t
+
+        super().__init__(params, state, [param_group])
+
     def zero_grad(self, set_to_none: bool = False) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])
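For clarity, here is a hedged sketch of the optimizer-state layout that the __init__ block above hands to super().__init__; plain tensors stand in for ShardedTensor, and the table name, shapes, and learning rate are made up:

    import torch

    # stand-ins for one table's sharded weight and its rowwise "momentum1" state
    weight = torch.zeros(4, 8)
    momentum1 = torch.zeros(4)

    params = {"table_0.weight": weight}
    state = {weight: {"table_0.momentum1": momentum1}}
    param_group = {"params": [weight], "lr": 0.01}

    # the fused optimizer base class then receives (params, state, [param_group])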
@@ -292,6 +342,61 @@ def step(self, closure: Any = None) -> None:
         # pyre-ignore [16]
         self._emb_module.set_learning_rate(self.param_groups[0]["lr"])

+    def set_sharded_embedding_weight_ids(
+        self, sharded_embedding_weight_ids: Optional[List[ShardedTensor]]
+    ) -> None:
+        self._sharded_embedding_weight_ids = sharded_embedding_weight_ids
+
+    def _post_state_dict_hook(self, curr_state: Dict[str, Any]) -> None:
+        logger.info("update optimizer state dict in state_dict_post_hook")
+        embedding_weight_ids = (
+            [
+                sharded_t._local_shards[0].tensor
+                for sharded_t in self._sharded_embedding_weight_ids
+            ]
+            if self._sharded_embedding_weight_ids is not None
+            else None
+        )
+        all_optimizer_states = self._emb_module.get_optimizer_state(
+            embedding_weight_ids,
+            no_snapshot=False,
+            should_flush=False,  # embedding weights were already flushed, no need to flush again here
+        )
+        emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
+        for emb_table in emb_table_config_copy:
+            emb_table.local_metadata.placement._device = torch.device("cpu")
+
+        # the order of table configs is deterministic, so iterate over it in the outer loop for a consistent traversal order across ranks
+        for table_config, opt_states in zip(
+            emb_table_config_copy,
+            all_optimizer_states,
+        ):
+            for key, sharded_t_dict in curr_state.items():
+                # update the zero-collision table's optimizer state
+                if f".{table_config.name}.weight" in key:
+                    for (_, opt_state_t), (sharded_t_k, sharded_t) in zip(
+                        opt_states.items(), sharded_t_dict.items()
+                    ):
+                        logger.info(
+                            f"update optimizer state for table {table_config.name} with state shape {opt_state_t.shape}, rank={self._my_rank}, weight_count_per_rank={self._table_name_to_weight_count_per_rank.get(table_config.name, None)}"
+                        )
+                        sharded_t.local_shards()[0].tensor = opt_state_t
+                        create_virtual_table_local_metadata(
+                            # pyre-ignore [6]
+                            table_config.local_metadata,
+                            opt_state_t,
+                            self._my_rank,
+                        )
+                        for shard in sharded_t.local_shards():
+                            shard.metadata = table_config.local_metadata
+                        new_sharded_t = ShardedTensor._init_from_local_shards(
+                            sharded_t.local_shards(),
+                            None,
+                            None,
+                            process_group=self._pg,
+                        )
+                        sharded_t_dict[sharded_t_k] = new_sharded_t
+

 class EmbeddingFusedOptimizer(FusedOptimizer):
     def __init__(  # noqa C901
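A hedged sketch of how the two new methods above might be driven from a state_dict post hook; the glue function name and its arguments are assumptions for illustration, not this PR's actual call site:

    # Hypothetical glue: refresh the sorted weight-id tensors first, then let
    # the optimizer rewrite its KV-ZCH states inside an already-built state dict.
    def run_optimizer_state_dict_hook(optimizer, weight_id_sharded_tensors, state_dict):
        optimizer.set_sharded_embedding_weight_ids(weight_id_sharded_tensors)
        optimizer._post_state_dict_hook(state_dict)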
@@ -1330,7 +1435,7 @@ def get_named_split_embedding_weights_snapshot(self, prefix: str = "") -> Iterat
             return

         pmt_list, weight_ids_list, bucket_cnt_list = self.split_embedding_weights(
-            no_snapshot=False
+            no_snapshot=False, should_flush=True
         )
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
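The should_flush=True here pairs with the should_flush=False used when optimizer state is read in _post_state_dict_hook. A rough sketch of the intended flush-once sequence, where `kernel` is a hypothetical handle to this module:

    # flush exactly once, when the embedding weights are snapshotted ...
    pmt_list, weight_ids_list, bucket_cnt_list = kernel.split_embedding_weights(
        no_snapshot=False, should_flush=True
    )
    # ... and reuse that flush when the optimizer state is read afterwards
    opt_states = kernel.emb_module.get_optimizer_state(
        weight_ids_list, no_snapshot=False, should_flush=False
    )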
@@ -1381,12 +1486,16 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_state.fill_(-1)

     # pyre-ignore [15]
-    def split_embedding_weights(self, no_snapshot: bool = True) -> Tuple[
+    def split_embedding_weights(
+        self, no_snapshot: bool = True, should_flush: bool = False
+    ) -> Tuple[
         List[PartiallyMaterializedTensor],
         Optional[List[torch.Tensor]],
         Optional[List[torch.Tensor]],
     ]:
-        return self.emb_module.split_embedding_weights(no_snapshot)
+        return self.emb_module.split_embedding_weights(
+            no_snapshot, should_flush=should_flush
+        )

     def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
         # reset split weights during training