@@ -87,7 +87,6 @@ def get_latest_checkpoint(path: str, remote: bool):
 
 def main(args):
     args = parse_args(args)
-
     if torch.cuda.is_available():
         # This enables tf32 on Ampere GPUs which is only 8% slower than
        # float16 and almost as accurate as float32
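
A note on the tf32 comment above: the standard way to enable TF32 matmuls and convolutions on Ampere-class GPUs is through the backend flags below. This is a hedged sketch of the usual PyTorch switches, not necessarily the exact lines that follow in this file.

    import torch

    if torch.cuda.is_available():
        # Allow TF32 for matmul and cuDNN convolutions on Ampere and newer GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
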
@@ -306,9 +305,7 @@ def main(args):
             "fp32": torch.float32,
         }
         mixed_precision = MixedPrecision(
-            param_dtype=type_name_to_class[args.precision],
             reduce_dtype=type_name_to_class[args.fsdp_gradient_reduction_precision],
-            buffer_dtype=type_name_to_class[args.fsdp_buffer_precision],
         )
         layers = set()
         for module in model.modules():
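
For context on the MixedPrecision change above: with param_dtype and buffer_dtype left unset, FSDP keeps parameters and buffers in their native precision and only casts gradients for the reduction step. A minimal illustrative sketch of that configuration (the bfloat16 choice here is an assumption, not taken from the repository):

    import torch
    from torch.distributed.fsdp import MixedPrecision

    # Only reduce_dtype is set: parameters and buffers keep their original dtype,
    # while gradients are cast to bfloat16 for the reduce-scatter/all-reduce.
    mixed_precision = MixedPrecision(reduce_dtype=torch.bfloat16)
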
@@ -408,6 +405,13 @@ def _param_name_without_fsdp_prefix(n):
     scaler = GradScaler() if args.precision == "amp" else None
     # optionally resume from a checkpoint
     start_epoch = 0
+    if args.fsdp:
+        FSDP.set_state_dict_type(
+            model,
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
+            FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
+        )
     if args.resume is not None:
         checkpoint = pt_load(args.resume, map_location='cpu')
         if 'epoch' in checkpoint:
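
The block added above switches the FSDP model to full (un-sharded) state dicts before any checkpoint is loaded, so both the resume path and the save path see the same format. A minimal sketch of the same call, assuming an already FSDP-wrapped model (the helper name is illustrative):

    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        FullOptimStateDictConfig,
        FullStateDictConfig,
        StateDictType,
    )

    def use_full_state_dicts(model):
        # offload_to_cpu=True keeps the gathered tensors on CPU rather than GPU;
        # rank0_only=False materializes the full state dict on every rank, which
        # matches a resume path where each rank loads the checkpoint itself.
        FSDP.set_state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
            FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
        )
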
@@ -419,7 +423,10 @@ def _param_name_without_fsdp_prefix(n):
             model.load_state_dict(sd)
             if optimizer is not None:
                 if args.fsdp:
-                    sharded_state_dict = FSDP.optim_state_dict_to_load(checkpoint["optimizer"], model, optimizer)
+                    optimizer_state_dict = checkpoint["optimizer"]
+                    optimizer_state_dict['state']['logit_scale']['exp_avg'] = optimizer_state_dict['state']['logit_scale']['exp_avg'].view(1)
+                    optimizer_state_dict['state']['logit_scale']['exp_avg_sq'] = optimizer_state_dict['state']['logit_scale']['exp_avg_sq'].view(1)
+                    sharded_state_dict = FSDP.optim_state_dict_to_load(model, optimizer, optimizer_state_dict)
                     optimizer.load_state_dict(sharded_state_dict)
                 else:
                     optimizer.load_state_dict(checkpoint["optimizer"])
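
For the optimizer resume above: the full optim state dict is keyed by parameter name, and the diff reshapes the Adam moments of the scalar logit_scale parameter to 1-element tensors (presumably because they are saved as 0-dim tensors) before converting back to the sharded layout. A hedged sketch of that flow; the helper name is an assumption, and the (model, optimizer, state_dict) argument order follows recent PyTorch FSDP releases:

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def resume_fsdp_optimizer(model, optimizer, checkpoint):
        osd = checkpoint["optimizer"]
        # Reshape the 0-dim Adam moments of the scalar logit_scale parameter
        # to 1-element tensors before the full->sharded conversion.
        for key in ("exp_avg", "exp_avg_sq"):
            osd["state"]["logit_scale"][key] = osd["state"]["logit_scale"][key].view(1)
        sharded_osd = FSDP.optim_state_dict_to_load(model, optimizer, osd)
        optimizer.load_state_dict(sharded_osd)
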
@@ -431,13 +438,7 @@ def _param_name_without_fsdp_prefix(n):
             model.load_state_dict(checkpoint)
         logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
 
-    if args.fsdp:
-        FSDP.set_state_dict_type(
-            model,
-            StateDictType.FULL_STATE_DICT,
-            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
-            FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
-        )
+
     # initialize datasets
     data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
     assert len(data), 'At least one train or eval dataset must be specified.'
@@ -506,14 +507,15 @@ def _param_name_without_fsdp_prefix(n):
 
         # Saving checkpoints.
         if args.save_logs:
-
+
             checkpoint_dict = {
                 "epoch": completed_epoch,
                 "name": args.name,
                 "state_dict": model.state_dict(),
-                "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict()
+                "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict(),
 
             }
+
             if scaler is not None:
                 checkpoint_dict["scaler"] = scaler.state_dict()
             if is_master(args):
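
On the save side, FSDP.optim_state_dict(model, optimizer) gathers the optimizer state into the same full, name-keyed format that FSDP.optim_state_dict_to_load consumes on resume. A small illustrative sketch of assembling the checkpoint dict under that assumption (the function name is not from the repository):

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    def build_checkpoint(model, optimizer, completed_epoch, run_name, use_fsdp):
        return {
            "epoch": completed_epoch,
            "name": run_name,
            # With FULL_STATE_DICT configured earlier, state_dict() returns the
            # consolidated, un-sharded parameters.
            "state_dict": model.state_dict(),
            # Gather FSDP optimizer state keyed by parameter name so it can be
            # restored later via FSDP.optim_state_dict_to_load.
            "optimizer": FSDP.optim_state_dict(model, optimizer) if use_fsdp else optimizer.state_dict(),
        }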