Fix lr scheduler #1261
base: main
Conversation
```
@@ -142,25 +142,26 @@ def linear_warmup_stable_decay(
    to ensure the learning rate does not drop below this minimum value.
    """
    warmup_stable_steps = warmup_steps + stable_steps
    # if we are in the warmup phase, return the warmup progress
    # if warmup_steps is 0, we will go to the next phase
    if current_step < warmup_steps:
```
Nice fix, thank you @CarlosGomes98! I agree `warmup_steps` should not have the +1, and it's not 0-indexed. As for `current_step`: because we always call `self.step += 1` at the beginning of train(), `current_step` is also not 0-indexed (https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L434).
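A toy illustration of that ordering (a simplified standalone sketch, not the actual train.py code):

```python
# The step counter is incremented before the rest of the training iteration,
# so by the time the lr scheduler runs, the count it sees is 1-based.
step = 0
seen_by_scheduler = []
for _ in range(3):
    step += 1                      # mirrors `self.step += 1` at the top of the loop
    seen_by_scheduler.append(step)
print(seen_by_scheduler)           # [1, 2, 3] -- never 0
```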
cc @tianyu-l for lr_scheduler change
There was a discussion in https://github.com/pytorch/torchtitan/pull/1010. I'm wondering whether we should have unit tests to verify the different use cases of `linear_warmup_stable_decay()`. It is a pure Python function, so it should not be hard to add a few test cases to check that we fix both issues (the checkpoint issue and the earlier divide-by-zero issue).
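For example, a sketch of what such tests could look like (the import path and the keyword names `warmup_steps`, `stable_steps`, `decay_steps` are assumptions and may need adjusting to the actual signature):

```python
# Hypothetical unit tests for the pure scheduling function; the module path and
# parameter names below are assumed, not taken from the repo.
import math

import pytest

from torchtitan.components.lr_scheduler import linear_warmup_stable_decay


def test_warmup_increments_are_even():
    # The reported bug: per-step warmup increments were uneven
    # (~0.00013 per step, then a larger jump at the end of warmup).
    vals = [
        linear_warmup_stable_decay(
            current_step=s, warmup_steps=5, stable_steps=0, decay_steps=5
        )
        for s in range(5)
    ]
    diffs = [b - a for a, b in zip(vals, vals[1:])]
    assert diffs == pytest.approx([1 / 5] * len(diffs))


def test_zero_decay_steps_does_not_divide_by_zero():
    # Edge case called out in the PR description: a decay phase of length 0
    # should not raise a ZeroDivisionError.
    value = linear_warmup_stable_decay(
        current_step=3, warmup_steps=2, stable_steps=1, decay_steps=0
    )
    assert math.isfinite(value)
```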
Glad that you found the bug!
I think the WSD lr scheduler is trickier than it sounds, due to the ambiguous definition of warmup/stable/decay steps. See my earlier comment in #1010 (review).
In particular, I believe your change would cause no learning to happen in the first iteration, unless `lr_min` is specified.
It is OK to me if there is a consensus that this should be the default behavior.
```
    elif lr_decay_type == "cosine":
        curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
        curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment
    return float(current_step / warmup_steps)
```
From what I read, if warmup=4, stable=2, decay=4, train steps=10, what will happen on each step is (here I'm assuming `lr_decay_type == "linear"` and `lr_min=0`):
step 1. lr=0, not learning
step 2. lr=0.25
step 3. lr=0.5
step 4. lr=0.75
step 5. lr=1
step 6. lr=1
step 7. lr=1
step 8. lr=0.75
step 9. lr=0.5
step 10. lr=0.25
This will cause no learning in the first step. Is this expected?
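These numbers can be reproduced with a small standalone script (a sketch of the behavior described above; treating the scheduler's input for train step n as the 0-indexed value n-1 is an assumption made explicit here):

```python
# warmup=4, stable=2, decay=4, 10 train steps, linear decay, lr_min=0
warmup, stable, decay = 4, 2, 4


def factor(step: int) -> float:
    # step is the 0-indexed value the scheduler lambda sees for train step (step + 1)
    if step < warmup:
        return step / warmup
    if step < warmup + stable:
        return 1.0
    return max(1.0 - (step - warmup - stable) / decay, 0.0)


for train_step in range(1, 11):
    print(f"step {train_step:2d}: lr factor = {factor(train_step - 1):.2f}")
# step  1: lr factor = 0.00   <- the "no learning in the first step" concern
# step  2: lr factor = 0.25
# ...
# step 10: lr factor = 0.25
```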
Addresses #1213, whose root cause was an arithmetic error in the lr scheduler.
The lr_scheduler adds 1 to the current step, which is not necessary and causes several errors in the lr steps.
Also took the chance to slightly simplify the logic (in my opinion) using early returns, as well as explicitly dealing with the edge case where decay_ratio is 0.
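For reference, here is a self-contained sketch of the adjustment logic along these lines (parameter names, the decay handling, and the guard for a zero-length decay phase are reconstructed from the diff and this description, not copied from the actual torchtitan code):

```python
import math


def wsd_adjustment(
    current_step: int,
    warmup_steps: int,
    stable_steps: int,
    decay_steps: int,
    lr_decay_type: str = "linear",
    lr_min: float = 0.0,
) -> float:
    """Multiplicative lr factor for a warmup-stable-decay schedule (sketch)."""
    warmup_stable_steps = warmup_steps + stable_steps

    # warmup phase: ramp linearly toward 1 (skipped entirely if warmup_steps == 0)
    if current_step < warmup_steps:
        return float(current_step / warmup_steps)

    # stable phase: hold the full learning rate
    if current_step < warmup_stable_steps:
        return 1.0

    # decay phase: guard the zero-length-decay edge case instead of dividing by zero
    if decay_steps == 0:
        return 1.0

    progress = min(float(current_step - warmup_stable_steps) / decay_steps, 1.0)
    if lr_decay_type == "linear":
        curr_adjustment = 1.0 - progress
    elif lr_decay_type == "cosine":
        curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
    else:
        raise ValueError(f"unknown lr_decay_type: {lr_decay_type}")

    # never drop below lr_min
    return lr_min + (1 - lr_min) * curr_adjustment
```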
Copying over the reasoning I shared in #1213:
Let's set up an example scenario where our lr schedule has warmup_steps of 5 and a target lr of 0.0008. We would expect the lr to increase on each step by 0.0008 / 5 = 0.00016. However, what actually happens is that the lr increases by roughly 0.00013 per step, except for the last step, which increases by more to make up the difference. This is clearly wrong.
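To make the arithmetic concrete (the old +1 formula below is an assumption reconstructed from the removed comment, not the exact original code):

```python
target_lr, warmup_steps = 0.0008, 5

expected_increment = target_lr / warmup_steps         # 0.00016, what we want per step
buggy_increment = target_lr / (warmup_steps + 1)      # ~0.000133, the observed 0.00013

print(expected_increment, round(buggy_increment, 6))  # 0.00016 0.000133
```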
We can easily resolve this by removing the assumption encoded in the comment `# 0-indexed step, hence + 1 adjustments`. Making those changes, we get the desired behaviour: the lr increases by 0.00016 on each warmup step.
This same behaviour was causing the bug I described in #1213: in the final step, we went 1 over the expected max_steps, which set the multiplicative lr adjustment factor to 0.
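For illustration, a minimal sketch of that final-step failure (again a reconstruction based on the +1 adjustment, not the exact original code):

```python
# With the extra +1, the decay progress reaches 1.0 on the last training step,
# so a linear decay with lr_min = 0 multiplies the lr by exactly 0.
decay_steps, lr_min = 10, 0.0

last_step = decay_steps - 1                        # 0-indexed final decay step
progress = (last_step + 1) / decay_steps           # the buggy +1 pushes this to 1.0
factor = lr_min + (1 - lr_min) * (1.0 - progress)
print(factor)                                      # 0.0 -> no learning on the final step
```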