
Commit 1347816

use use_orig_params=True (thanks to @nkflash) so that FSDP keeps the original parameter names and avoids erroneous parameter decay, and construct the set of parameter names to decay before FSDP wrapping (thanks to @rwightman)
1 parent: ca50a25


src/training/main.py

Lines changed: 15 additions & 7 deletions
@@ -226,6 +226,10 @@ def main(args):
         aug_cfg=args.aug_cfg,
         output_dict=True,
     )
+    # Prepare parameters to decay
+    exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
+    parameters_to_decay = set(n for n, p in model.named_parameters() if exclude(n,p))
+
     random_seed(args.seed, args.rank)
 
     if args.trace:
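Why this moves up front: once FSDP wraps the model, model.named_parameters() no longer reports the original names and shapes (wrapper prefixes are added, and without use_orig_params the parameters are flattened into 1-D FlatParameters), so a p.ndim < 2 check can no longer pick out gains and biases. Capturing the names before wrapping sidesteps that. Below is a minimal standalone sketch of the same pattern, not a verbatim excerpt: it uses a toy torch.nn model and a hypothetical no_decay_names set, which plays the role of the commit's parameters_to_decay (a set that, despite its name, collects the parameters later given weight_decay=0.).

    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(16, 32),
        nn.LayerNorm(32),
        nn.Linear(32, 4),
    )

    # Same rule as the commit: 1-D tensors, norm params, biases, and logit_scale
    # are kept out of weight decay.
    exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n

    # Record the names now, while named_parameters() still reports the original
    # module hierarchy; FSDP wrapping later prepends its own wrapper prefixes.
    no_decay_names = {n for n, p in model.named_parameters() if exclude(n, p)}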
@@ -297,6 +301,7 @@ def main(args):
                 transformer_auto_wrap_policy,
                 transformer_layer_cls=layers,
             ),
+            use_orig_params=True,
             device_id=device,
         )
         # avoid "RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory."
@@ -350,14 +355,17 @@ def main(args):
 
     if args.train_data or args.dataset_type == "synthetic":
         assert not args.trace, 'Cannot train with traced model'
-
-        exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n
-        include = lambda n, p: not exclude(n, p)
-
         named_parameters = list(model.named_parameters())
-        gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad]
-        rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad]
-
+        if args.distributed_engine == "fsdp":
+            def _param_name_without_fsdp_prefix(n):
+                n = n.replace("_fsdp_wrapped_module.", "")
+                n = n.replace("._checkpoint_wrapped_module", "")
+                return n
+            gain_or_bias_params = [p for n, p in named_parameters if _param_name_without_fsdp_prefix(n) in parameters_to_decay and p.requires_grad]
+            rest_params = [p for n, p in named_parameters if _param_name_without_fsdp_prefix(n) not in parameters_to_decay and p.requires_grad]
+        else:
+            gain_or_bias_params = [p for n, p in named_parameters if n in parameters_to_decay and p.requires_grad]
+            rest_params = [p for n, p in named_parameters if n not in parameters_to_decay and p.requires_grad]
         optimizer = optim.AdamW(
             [
                 {"params": gain_or_bias_params, "weight_decay": 0.},
