
Commit f891252

committed
use use_orig_params=True (thanks to @nkflash) so FSDP keeps the original parameter names and avoids erroneous parameter decay, and select params to decay by constructing the set of parameter names to decay before FSDP wrapping (thanks to @rwightman)
1 parent 085b978 commit f891252

File tree

1 file changed (+14, -7 lines)


src/training/main.py

Lines changed: 14 additions & 7 deletions
@@ -243,6 +243,9 @@ def main(args):
         precision=args.precision,
         output_dict=True,
     )
+    # Prepare parameters to decay
+    exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
+    parameters_to_decay = set(n for n, p in model.named_parameters() if not exclude(n,p))
 
     random_seed(args.seed, args.rank)
 
@@ -314,6 +317,7 @@ def main(args):
                 transformer_auto_wrap_policy,
                 transformer_layer_cls=layers,
             ),
+            use_orig_params=True,
             device_id=device,
         )
         # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory."
@@ -367,14 +371,17 @@ def main(args):
 
     if args.train_data or args.dataset_type == "synthetic":
         assert not args.trace, 'Cannot train with traced model'
-
-        exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
-        include = lambda n, p: not exclude(n, p)
-
         named_parameters = list(model.named_parameters())
-        gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
-        rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
-
+        if args.distributed_engine == "fsdp":
+            def _param_name_without_fsdp_prefix(n):
+                n = n.replace("_fsdp_wrapped_module.", "")
+                n = n.replace("._checkpoint_wrapped_module", "")
+                return n
+            gain_or_bias_params = [p for n, p in named_parameters if _param_name_without_fsdp_prefix(n) not in parameters_to_decay and p.requires_grad]
+            rest_params = [p for n, p in named_parameters if _param_name_without_fsdp_prefix(n) in parameters_to_decay and p.requires_grad]
+        else:
+            gain_or_bias_params = [p for n, p in named_parameters if n not in parameters_to_decay and p.requires_grad]
+            rest_params = [p for n, p in named_parameters if n in parameters_to_decay and p.requires_grad]
         optimizer = optim.AdamW(
             [
                 {"params": gain_or_bias_params, "weight_decay": 0.},
