
Commit c810597

fix optimizer resuming in FSDP and remove param/buffer precision
1 parent c455459 commit c810597

File tree: 2 files changed, +15 −19 lines

src/training/main.py

Lines changed: 15 additions & 13 deletions
@@ -87,7 +87,6 @@ def get_latest_checkpoint(path: str, remote : bool):
 
 def main(args):
     args = parse_args(args)
-
     if torch.cuda.is_available():
         # This enables tf32 on Ampere GPUs which is only 8% slower than
         # float16 and almost as accurate as float32
@@ -306,9 +305,7 @@ def main(args):
             "fp32": torch.float32,
         }
         mixed_precision = MixedPrecision(
-            param_dtype=type_name_to_class[args.precision],
             reduce_dtype=type_name_to_class[args.fsdp_gradient_reduction_precision],
-            buffer_dtype=type_name_to_class[args.fsdp_buffer_precision],
         )
         layers = set()
         for module in model.modules():
@@ -408,6 +405,13 @@ def _param_name_without_fsdp_prefix(n):
     scaler = GradScaler() if args.precision == "amp" else None
     # optionally resume from a checkpoint
     start_epoch = 0
+    if args.fsdp:
+        FSDP.set_state_dict_type(
+            model,
+            StateDictType.FULL_STATE_DICT,
+            FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
+            FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
+        )
     if args.resume is not None:
         checkpoint = pt_load(args.resume, map_location='cpu')
         if 'epoch' in checkpoint:
@@ -419,7 +423,10 @@ def _param_name_without_fsdp_prefix(n):
             model.load_state_dict(sd)
             if optimizer is not None:
                 if args.fsdp:
-                    sharded_state_dict = FSDP.optim_state_dict_to_load(checkpoint["optimizer"], model, optimizer)
+                    optimizer_state_dict = checkpoint["optimizer"]
+                    optimizer_state_dict['state']['logit_scale']['exp_avg'] = optimizer_state_dict['state']['logit_scale']['exp_avg'].view(1)
+                    optimizer_state_dict['state']['logit_scale']['exp_avg_sq'] = optimizer_state_dict['state']['logit_scale']['exp_avg_sq'].view(1)
+                    sharded_state_dict = FSDP.optim_state_dict_to_load(model, optimizer, optimizer_state_dict)
                     optimizer.load_state_dict(sharded_state_dict)
                 else:
                     optimizer.load_state_dict(checkpoint["optimizer"])
@@ -431,13 +438,7 @@ def _param_name_without_fsdp_prefix(n):
             model.load_state_dict(checkpoint)
             logging.info(f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})")
 
-    if args.fsdp:
-        FSDP.set_state_dict_type(
-            model,
-            StateDictType.FULL_STATE_DICT,
-            FullStateDictConfig(rank0_only=True, offload_to_cpu=True),
-            FullOptimStateDictConfig(rank0_only=True, offload_to_cpu=True),
-        )
+
     # initialize datasets
     data = get_data(args, (preprocess_train, preprocess_val), epoch=start_epoch, tokenizer=get_tokenizer(args.model))
     assert len(data), 'At least one train or eval dataset must be specified.'
@@ -506,14 +507,15 @@ def _param_name_without_fsdp_prefix(n):
 
         # Saving checkpoints.
         if args.save_logs:
-
+
             checkpoint_dict = {
                 "epoch": completed_epoch,
                 "name": args.name,
                 "state_dict": model.state_dict(),
-                "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict()
+                "optimizer": FSDP.optim_state_dict(model, optimizer) if args.fsdp else optimizer.state_dict(),
 
             }
+
             if scaler is not None:
                 checkpoint_dict["scaler"] = scaler.state_dict()
             if is_master(args):
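
Read together, the main.py changes do three things: the FULL_STATE_DICT configuration is now set before resuming (with rank0_only=False, so every rank works with the full, unsharded state while loading), the full optimizer state dict is passed to FSDP.optim_state_dict_to_load in its (model, optimizer, state_dict) argument order, and the Adam moments of the 0-dim logit_scale parameter are reshaped to 1-element tensors before resharding. A minimal sketch of the resulting save/resume round trip, assuming PyTorch's FSDP API as used in the diff; the helper function names and standalone structure are illustrative, not part of the commit:

import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullOptimStateDictConfig,
    FullStateDictConfig,
    StateDictType,
)


def configure_full_state_dict(model):
    # Configure FULL_STATE_DICT once, before any load or save. The commit
    # switches rank0_only to False so every rank sees the full state dict,
    # and keeps offload_to_cpu=True to avoid materializing it on GPU.
    FSDP.set_state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        FullStateDictConfig(rank0_only=False, offload_to_cpu=True),
        FullOptimStateDictConfig(rank0_only=False, offload_to_cpu=True),
    )


def resume_optimizer(model, optimizer, checkpoint):
    osd = checkpoint["optimizer"]
    # The commit reshapes the Adam moments of the scalar logit_scale
    # parameter to 1-element tensors before handing them to FSDP.
    osd["state"]["logit_scale"]["exp_avg"] = osd["state"]["logit_scale"]["exp_avg"].view(1)
    osd["state"]["logit_scale"]["exp_avg_sq"] = osd["state"]["logit_scale"]["exp_avg_sq"].view(1)
    # Fixed argument order: (model, optimizer, full_optim_state_dict).
    sharded = FSDP.optim_state_dict_to_load(model, optimizer, osd)
    optimizer.load_state_dict(sharded)


def save_checkpoint(model, optimizer, completed_epoch, path):
    # Saving mirrors loading: gather a full, rank-agnostic optimizer state.
    torch.save(
        {
            "epoch": completed_epoch,
            "state_dict": model.state_dict(),
            "optimizer": FSDP.optim_state_dict(model, optimizer),
        },
        path,
    )

Moving set_state_dict_type ahead of the resume path also means optim_state_dict_to_load runs under the state-dict configuration it expects; in the previous ordering the configuration was only applied after resuming, which is presumably what broke optimizer resumption.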

src/training/params.py

Lines changed: 0 additions & 6 deletions
@@ -398,12 +398,6 @@ def parse_args(args):
         nargs='+',
         help="Module names to wrap for gradient checkpointing when FSDP is used",
     )
-    parser.add_argument(
-        "--fsdp-buffer-precision",
-        choices=["bf16", "fp16", "fp32"],
-        default="fp32",
-        help="FSDP floating point precision for buffers"
-    )
     parser.add_argument(
         "--fsdp-gradient-reduction-precision",
         choices=["bf16", "fp16", "fp32"],