
Fix lr scheduler #1261


Open: CarlosGomes98 wants to merge 1 commit into main

Conversation

@CarlosGomes98 (Contributor) commented Jun 4, 2025

Addresses #1213, whose root cause was an arithmetic error in the lr scheduler.

The lr_scheduler adds 1 to the current step; this is unnecessary and causes several errors in the per-step lr values.

I also took the chance to slightly simplify the logic (in my opinion) using early returns, and to explicitly handle the edge case where decay_ratio is 0.

Copying over the reasoning I shared in #1213

Let's take an example scenario where our lr schedule has warmup_steps of 5 and a target lr of 0.0008.
We would expect the lr to increase at each step by 0.0008 / 5 = 0.00016. However, what actually happens is shown below:

[rank0]:[titan] 2025-06-04 14:42:54,945 - root - INFO - Step 1 lr: 0.00013333333333333334
[rank0]:[titan] 2025-06-04 14:42:54,946 - root - INFO - step:  1  loss:  1.7905  memory: 19.94GiB(42.08%)  tps: 519,779  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:54,946 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-04 14:42:55,316 - root - INFO - Step 2 lr: 0.0002666666666666667
[rank0]:[titan] 2025-06-04 14:42:55,316 - root - INFO - step:  2  loss:  1.9953  memory: 20.45GiB(43.17%)  tps: 2,121,432  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:55,686 - root - INFO - Step 3 lr: 0.0004
[rank0]:[titan] 2025-06-04 14:42:55,686 - root - INFO - step:  3  loss:  1.7081  memory: 20.45GiB(43.17%)  tps: 2,125,984  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:56,079 - root - INFO - Step 4 lr: 0.0005333333333333334
[rank0]:[titan] 2025-06-04 14:42:56,079 - root - INFO - step:  4  loss:  1.6267  memory: 20.45GiB(43.17%)  tps: 2,004,664  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:42:56,476 - root - INFO - Step 5 lr: 0.0008
[rank0]:[titan] 2025-06-04 14:42:56,477 - root - INFO - step:  5  loss:  1.5076  memory: 20.45GiB(43.17%)  tps: 1,977,805  tflops: 0.00  mfu: 0.00%

This is clearly wrong: the lr increases by about 0.000133 at each step, except for the last warmup step, which jumps by about 0.000267 to make up the difference.

We can easily resolve this by removing the + 1 adjustments introduced under the assumption stated in the comments ("# 0-indexed step, hence + 1 adjustments").
Making those changes, we get the desired behaviour of:

[rank0]:[titan] 2025-06-04 14:46:48,496 - root - INFO - Step 1 lr: 0.00016
[rank0]:[titan] 2025-06-04 14:46:48,497 - root - INFO - step:  1  loss:  1.8309  memory: 19.94GiB(42.08%)  tps: 512,095  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:48,497 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40
[rank0]:[titan] 2025-06-04 14:46:48,867 - root - INFO - Step 2 lr: 0.00032
[rank0]:[titan] 2025-06-04 14:46:48,867 - root - INFO - step:  2  loss:  1.9606  memory: 20.45GiB(43.17%)  tps: 2,124,707  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:49,233 - root - INFO - Step 3 lr: 0.00048
[rank0]:[titan] 2025-06-04 14:46:49,233 - root - INFO - step:  3  loss:  1.6091  memory: 20.45GiB(43.17%)  tps: 2,147,417  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:49,640 - root - INFO - Step 4 lr: 0.00064
[rank0]:[titan] 2025-06-04 14:46:49,641 - root - INFO - step:  4  loss:  1.4623  memory: 20.45GiB(43.17%)  tps: 1,931,065  tflops: 0.00  mfu: 0.00%
[rank0]:[titan] 2025-06-04 14:46:50,039 - root - INFO - Step 5 lr: 0.0008
[rank0]:[titan] 2025-06-04 14:46:50,040 - root - INFO - step:  5  loss:  1.6065  memory: 20.45GiB(43.17%)  tps: 1,971,288  tflops: 0.00  mfu: 0.00%
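
The arithmetic behind the two logs can be reproduced with a small sketch. This is a reconstruction consistent with the values above, not the torchtitan code itself; it assumes a 1-indexed training step and the + 1 behaviour described earlier:

lr, warmup_steps = 8e-4, 5

for step in range(1, warmup_steps + 1):
    # old behaviour: the extra "+ 1" effectively divides by (warmup_steps + 1),
    # and the last warmup step already falls out of the warmup branch, jumping to 1.0
    old_factor = 1.0 if step >= warmup_steps else step / (warmup_steps + 1)
    # fixed behaviour: step / warmup_steps reaches exactly 1.0 on the last warmup step
    new_factor = step / warmup_steps
    print(f"step {step}: old lr={lr * old_factor:.6f}  fixed lr={lr * new_factor:.6f}")

# old:   0.000133, 0.000267, 0.000400, 0.000533, 0.000800  (uneven jump at the end)
# fixed: 0.000160, 0.000320, 0.000480, 0.000640, 0.000800  (even steps of 0.00016)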

This same behaviour caused the bug I described in #1213: on the final step, the scheduler went 1 step past the expected max_steps, which set the multiplicative lr adjustment factor to 0.
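
A rough sketch of the decay-side arithmetic behind that bug (again a reconstruction, not the exact implementation; it assumes the schedule lambda sees a 0-indexed step, linear decay, and lr_min = 0, and the exact indexing is discussed further down in this thread):

warmup_steps, stable_steps, decay_steps = 5, 0, 10
warmup_stable_steps = warmup_steps + stable_steps
total_steps = warmup_stable_steps + decay_steps   # 15 training steps in total

def linear_decay_factor(step_used: int, lr_min: float = 0.0) -> float:
    progress = (step_used - warmup_stable_steps) / decay_steps
    return lr_min + (1 - lr_min) * (1.0 - progress)

last_step = total_steps - 1                       # 0-indexed final training step
print(linear_decay_factor(last_step + 1))         # old "+ 1": progress hits 1.0, factor is 0.0
print(linear_decay_factor(last_step))             # fixed: factor is 0.1, the last step still learns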

@@ -142,25 +142,26 @@ def linear_warmup_stable_decay(
    to ensure the learning rate does not drop below this minimum value.
    """
    warmup_stable_steps = warmup_steps + stable_steps
    # if we are in the warmup phase, return the warmup progress
    # if warmup_steps is 0, we will go to the next phase
    if current_step < warmup_steps:
Contributor left a comment:

Nice fix, thank you @CarlosGomes98! I agree warmup_steps should not have the + 1, and it is not 0-indexed. As for current_step, because we always call self.step += 1 at the beginning of train(), current_step is also not 0-indexed. (https://github.com/pytorch/torchtitan/blob/refs/heads/main/torchtitan/train.py#L434)

@wwwjn requested a review from tianyu-l on June 4, 2025 16:27
@wwwjn (Contributor) left a comment:

cc @tianyu-l for lr_scheduler change

@fegin (Contributor) left a comment:

There was a discussion in https://github.com/pytorch/torchtitan/pull/1010. I'm wondering whether we should have unit tests to verify different use cases of linear_warmup_stable_decay(). It is a pure Python function, so it should not be hard to add some test cases to check that we fix both issues (the checkpoint issue and the earlier divide-by-zero issue).
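
A hypothetical sketch of what such a test could look like. The function name comes from this PR, but the import path, keyword arguments, and expected values here are assumptions that would need to be matched to the real signature and to the indexing convention discussed below:

# Hypothetical test sketch -- import path and signature are assumed, not the real API.
import pytest

from torchtitan.components.lr_scheduler import linear_warmup_stable_decay  # assumed path


def test_warmup_factors_are_evenly_spaced():
    # Mirrors the logs in this PR: warmup_steps=5 should give factors
    # 0.2, 0.4, 0.6, 0.8, 1.0, i.e. an even increase of 1/5 per step.
    factors = [
        linear_warmup_stable_decay(
            current_step=step, warmup_steps=5, stable_steps=10, decay_steps=5
        )
        for step in range(1, 6)
    ]
    assert factors == pytest.approx([0.2, 0.4, 0.6, 0.8, 1.0])

A second case could assert that the factor on the final training step never drops to exactly 0, covering the regression from #1213.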

@tianyu-l (Contributor) left a comment:

Glad that you found the bug!

I think the WSD lr scheduler is trickier than it sounds, due to the ambiguous definition of warmup/stable/decay steps. See my earlier comment in #1010 (review).

In particular, I believe your change would cause no learning to happen in the first iteration, unless lr_min is specified.

It is OK with me if there is a consensus that this should be the default behavior.

elif lr_decay_type == "cosine":
    curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress))
    curr_adjustment = lr_min + (1 - lr_min) * curr_adjustment
return float(current_step / warmup_steps)
Contributor left a comment:

From what I read, if warmup=4, stable=2, decay=4, train steps=10, what will happen on each step is:
(Here I'm assuming lr_decay_type == "linear" and lr_min=0)

step 1. lr=0, not learning
step 2. lr=0.25
step 3. lr=0.5
step 4. lr=0.75
step 5. lr=1
step 6. lr=1
step 7. lr=1
step 8. lr=0.75
step 9. lr=0.5
step 10. lr=0.25

This will cause no learning in the first step. Is this expected?
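
The table above can be reproduced with a small sketch under the same reading (a 0-indexed current_step fed to the schedule, linear decay, lr_min = 0); this is an illustration of that reading, not the torchtitan code:

warmup, stable, decay = 4, 2, 4   # 10 training steps in total

def factor(current_step: int) -> float:
    if current_step < warmup:
        return current_step / warmup              # step 1 -> 0/4 = 0, i.e. no learning
    if current_step < warmup + stable:
        return 1.0
    return 1.0 - (current_step - warmup - stable) / decay

for train_step in range(1, 11):
    print(train_step, factor(train_step - 1))     # 0.0, 0.25, 0.5, 0.75, 1.0, 1.0, 1.0, 0.75, 0.5, 0.25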
