
Commit 4817a91

Myle Ott authored and facebook-github-bot committed Dec 5, 2020
Cleanup CosineLRScheduler and change defaults (#1487)
Summary: Pull Request resolved: fairinternal/fairseq-py#1487

Here's the code for CosineLRScheduler that I used as a reference: https://github.com/pytorch/fairseq/blob/577e4fa78a295fd7cd3ee7e9fd4b936ca800ebea/fairseq/optim/lr_scheduler/cosine_lr_schedul

In the reference:
- `warmup_init_lr` defaults to `args.lr[0]`
- `warmup_end_lr` defaults to `args.max_lr`
- `min_lr` defaults to `args.lr[0]` (note that there's also an `args.min_lr` option defined in the global fairseq config, but it is unused by the cosine scheduler)
- `max_lr` is a required option

This diff removes `max_lr` and replaces it with `lr[0]`, to be more consistent with the other LR schedulers. We then add an explicit `min_lr` option to the Config.

Test Plan: Imported from OSS

Reviewed By: alexeib

Differential Revision: D25342180

Pulled By: myleott

fbshipit-source-id: 61281666e68839da8efc4714c2ce8c49dc4c8e6e
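To make the renaming concrete, here is a minimal, illustrative sketch (not fairseq code; `resolve_cosine_opts` is a made-up helper) of how the options resolve after this change: `lr[0]` takes over the role of the old `max_lr`, `min_lr` becomes an explicit config field defaulting to 0.0, and a negative `warmup_init_lr` now falls back to `min_lr`.

```python
# Illustrative sketch only, mirroring the new __init__ logic in the diff below;
# the helper name is made up for this example.
from collections.abc import Collection
from typing import List, Union


def resolve_cosine_opts(
    lr: Union[float, List[float]],   # cfg.lr (interpolated from optimization.lr)
    min_lr: float = 0.0,             # cfg.min_lr, the new explicit field
    warmup_init_lr: float = -1.0,    # cfg.warmup_init_lr; < 0 means "use the default"
):
    # lr[0] now plays the role that max_lr used to play
    max_lr = lr[0] if isinstance(lr, Collection) else lr
    assert max_lr > min_lr, f"max_lr (={lr}) must be more than min_lr (={min_lr})"
    warmup_end_lr = max_lr
    if warmup_init_lr < 0:
        warmup_init_lr = min_lr
    return max_lr, min_lr, warmup_end_lr, warmup_init_lr


# e.g. the updated wikitext-103 recipe below: --lr 1.0 --min-lr 0.0001
print(resolve_cosine_opts([1.0], min_lr=0.0001))  # -> (1.0, 0.0001, 1.0, 0.0001)
```

Compare this with the removed lines in `cosine_lr_scheduler.py` below, where `warmup_init_lr` used to fall back to `lr[0]` and `min_lr` was clamped to zero unless positive.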
1 parent 72a25a4 commit 4817a91

File tree

6 files changed, +26 -26 lines

examples/language_model/README.adaptive_inputs.md (+2, -2)

@@ -19,8 +19,8 @@ fairseq-train --task language_modeling \
 data-bin/wikitext-103 \
 --save-dir checkpoints/transformer_wikitext-103 \
 --arch transformer_lm_wiki103 \
---max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
---warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
+--max-update 286000 --lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
+--warmup-updates 16000 --warmup-init-lr 1e-07 --stop-min-lr 1e-09 --optimizer nag --min-lr 0.0001 --clip-norm 0.1 \
 --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --tokens-per-sample 3072 --seed 1 \
 --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d
 ```

examples/pay_less_attention_paper/README.md (+2, -2)

@@ -140,7 +140,7 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
 --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
 --ddp-backend=no_c10d --max-tokens 3584 \
 --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
---lr-shrink 1 --max-lr 0.001 --lr 1e-7 --warmup-init-lr 1e-07 \
+--lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \
 --t-mult 1 --lr-period-updates 20000 \
 --arch lightconv_wmt_en_de_big --save-dir $SAVE \
 --dropout 0.3 --attention-dropout 0.1 --weight-dropout 0.1 \
@@ -165,7 +165,7 @@ python -m torch.distributed.launch --nproc_per_node 8 $(which fairseq-train) \
 --stop-min-lr 1e-09 --update-freq 16 --attention-dropout 0.1 --keep-last-epochs 10 \
 --ddp-backend=no_c10d --max-tokens 3584 \
 --lr-scheduler cosine --warmup-init-lr 1e-7 --warmup-updates 10000 \
---lr-shrink 1 --max-lr 0.001 --lr 1e-7 --warmup-init-lr 1e-07 \
+--lr-shrink 1 --lr 0.001 --min-lr 1e-7 --warmup-init-lr 1e-07 \
 --t-mult 1 --lr-period-updates 70000 \
 --arch lightconv_wmt_en_fr_big --save-dir $SAVE \
 --dropout 0.1 --attention-dropout 0.1 --weight-dropout 0.1 \

examples/quant_noise/README.md (+2, -2)

@@ -208,7 +208,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \
 --ddp-backend no_c10d \
 --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 \
 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \
---lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 1.0 --t-mult 2.0 \
+--min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 1.0 --t-mult 2.0 \
 --max-tokens 3072 --tokens-per-sample 3072 --momentum 0.99 --optimizer nag \
 --sample-break-mode none --update-freq 3 \
 --warmup-init-lr 1e-07 --warmup-updates 16000 \
@@ -269,7 +269,7 @@ fairseq-train --task language_modeling /path/to/wikitext-103/data \
 --ddp-backend no_c10d \
 --decoder-attention-heads 8 --decoder-embed-dim 1024 --decoder-ffn-embed-dim 4096 --decoder-input-dim 1024 --decoder-layers 16 --decoder-normalize-before --decoder-output-dim 1024 \
 --fp16 --keep-last-epochs -1 \
---lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --max-lr 0.05 --stop-min-lr 1e-09 \
+--min-lr 0.0001 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 --lr 0.05 --stop-min-lr 1e-09 \
 --max-tokens 2944 --tokens-per-sample 2944\
 --momentum 0.99 --no-epoch-checkpoints --no-progress-bar --optimizer nag --required-batch-size-multiple 8 \
 --sample-break-mode none --t-mult 2.0 --skip-invalid-size-inputs-valid-test \

examples/truncated_bptt/README.md (+1, -1)

@@ -37,7 +37,7 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 fairseq-train \
 --arch transformer_xl --n-layer 16 --d-model 410 --n-head 10 \
 --d-head 41 --d-inner 2100 --dropout 0.1 --dropatt 0.0 --mem-len 150 \
 --optimizer adam --clip-norm 0.25 \
---lr-scheduler cosine --warmup-updates 0 --lr 0.0 --max-lr 0.00025 \
+--lr-scheduler cosine --warmup-updates 0 --min-lr 0.0 --lr 0.00025 \
 --log-format json --log-interval 25 \
 --fp16
 ```

examples/wav2vec/README.md (+3, -3)

@@ -186,7 +186,7 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa
 
 ```
 $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 --save-interval 1 --no-epoch-checkpoints \
---arch wav2vec --task audio_pretraining --lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --max-lr 0.005 --lr-scheduler cosine \
+--arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 --optimizer adam --lr 0.005 --lr-scheduler cosine \
 --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1)] \
 --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
 --skip-connections-agg --residual-scale 0.5 --log-compression --warmup-updates 500 --warmup-init-lr 1e-07 --criterion wav2vec --num-negatives 10 \
@@ -244,8 +244,8 @@ $ python examples/wav2vec/wav2vec_manifest.py /path/to/waves --dest /manifest/pa
 
 ```
 $ python train.py /manifest/path --save-dir /model/path --num-workers 6 --fp16 --max-update 400000 \
---save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --lr 1e-06 --stop-min-lr 1e-09 \
---optimizer adam --max-lr 1e-05 --lr-scheduler cosine \
+--save-interval 1 --no-epoch-checkpoints --arch wav2vec --task audio_pretraining --min-lr 1e-06 --stop-min-lr 1e-09 \
+--optimizer adam --lr 1e-05 --lr-scheduler cosine \
 --conv-feature-layers [(512, 10, 5), (512, 8, 4), (512, 4, 2), (512, 4, 2), (512, 4, 2), (512, 1, 1), (512, 1, 1), (512, 1, 1)] \
 --conv-aggregator-layers [(512, 2, 1), (512, 3, 1), (512, 4, 1), (512, 5, 1), (512, 6, 1), (512, 7, 1), (512, 8, 1), (512, 9, 1), (512, 10, 1), (512, 11, 1), (512, 12, 1), (512, 13, 1)] \
 --activation gelu --offset auto --skip-connections-agg --residual-scale 0.5 \

fairseq/optim/lr_scheduler/cosine_lr_scheduler.py (+16, -16)

@@ -26,9 +26,11 @@ class CosineLRScheduleConfig(FairseqDataclass):
             "help": "initial learning rate during warmup phase; default is cfg.lr"
         },
     )
-    max_lr: float = field(
-        default=1.0, metadata={"help": "max learning rate, must be more than cfg.lr"}
+    lr: List[float] = field(
+        default=II("optimization.lr"),
+        metadata={"help": "max learning rate, must be more than cfg.min_lr"},
     )
+    min_lr: float = field(default=0.0, metadata={"help": "min learning rate"})
     t_mult: float = field(
         default=1.0, metadata={"help": "factor to grow the length of each period"}
     )
@@ -38,7 +40,7 @@ class CosineLRScheduleConfig(FairseqDataclass):
     lr_shrink: float = field(
         default=0.1, metadata={"help": "shrink factor for annealing"}
     )
-    lr: List[float] = II("optimization.lr")
+    # This is not required, but is for convenience in inferring lr_period_updates
     max_update: int = II("optimization.max_update")
 
 
@@ -50,7 +52,7 @@ class CosineLRSchedule(FairseqLRScheduler):
 
     We also support a warmup phase where we linearly increase the learning rate
     from some initial learning rate (``--warmup-init-lr``) until the configured
-    max learning rate (``--max-lr``).
+    max learning rate (``--lr``).
 
     During warmup::
 
@@ -59,7 +61,7 @@ class CosineLRSchedule(FairseqLRScheduler):
 
     After warmup::
 
-      lr = lr_min + 0.5*(lr_max - lr_min)*(1 + cos(t_curr / t_i))
+      lr = cfg.min_lr + 0.5*(cfg.lr - cfg.min_lr)*(1 + cos(t_curr / t_i))
 
     where ``t_curr`` is current percentage of updates within the current period
     range and ``t_i`` is the current period range, which is scaled by ``t_mul``
@@ -74,23 +76,21 @@ def __init__(self, cfg: CosineLRScheduleConfig, fairseq_optimizer):
                 f" Consider --lr-scheduler=fixed instead. ({cfg.lr})"
             )
 
-        warmup_end_lr = cfg.max_lr
-        lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
-        if cfg.warmup_init_lr < 0:
-            cfg.warmup_init_lr = lr
+        self.max_lr = cfg.lr[0] if isinstance(cfg.lr, Collection) else cfg.lr
+        assert (
+            self.max_lr > cfg.min_lr
+        ), f"max_lr (={cfg.lr}) must be more than min_lr (={cfg.min_lr})"
 
-        # default min_lr=-1 -> cosine anneale to lr=0.0
-        # otherwise pick min_lr from config
-        self.min_lr = cfg.min_lr if cfg.min_lr > 0.0 else 0.0
-        self.max_lr = lr
-        assert self.max_lr > self.min_lr, "max_lr must be more than lr"
+        warmup_end_lr = self.max_lr
+        if cfg.warmup_init_lr < 0:
+            cfg.warmup_init_lr = cfg.min_lr
 
         self.t_mult = cfg.t_mult
         self.period = cfg.lr_period_updates
 
         if self.period <= 0:
             assert (
-                cfg.max_update >= 0
+                cfg.max_update > 0
             ), "Either --max_update or --lr-period-updates must be set"
             self.period = cfg.max_update - cfg.warmup_updates
 
@@ -136,7 +136,7 @@ def step_update(self, num_updates):
                 t_curr = curr_updates - (self.period * i)
 
             lr_shrink = self.lr_shrink ** i
-            min_lr = self.min_lr * lr_shrink
+            min_lr = self.cfg.min_lr * lr_shrink
             max_lr = self.max_lr * lr_shrink
 
             self.lr = min_lr + 0.5 * (max_lr - min_lr) * (
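For readers skimming that last hunk: a hedged sketch of the post-warmup annealing it touches, simplified to `t_mult == 1` and using the conventional π factor inside the cosine (the docstring above writes the formula without it); `annealed_lr` is an illustrative name, not a fairseq function.

```python
# Hedged sketch (assumes t_mult == 1 and warmup already subtracted from curr_updates):
# both ends of the cosine range shrink by lr_shrink ** i each period, which is
# what the min_lr -> cfg.min_lr change above feeds into.
import math


def annealed_lr(curr_updates: int, max_lr: float, min_lr: float,
                period: int, lr_shrink: float) -> float:
    i = curr_updates // period              # index of the current restart period
    t_curr = curr_updates - period * i      # progress within that period
    shrink = lr_shrink ** i
    lo, hi = min_lr * shrink, max_lr * shrink
    return lo + 0.5 * (hi - lo) * (1 + math.cos(math.pi * t_curr / period))


# e.g. the wikitext-103 settings: --lr 1.0 --min-lr 0.0001 --lr-period-updates 270000 --lr-shrink 0.75
print(annealed_lr(0, 1.0, 0.0001, 270000, 0.75))       # ~1.0 at the start of the first period
print(annealed_lr(135000, 1.0, 0.0001, 270000, 0.75))  # ~0.5 halfway through it
```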
