Adding Token Counter for Online RL #110
Conversation
Thanks @rithwik-db, it looks good to me, but I would like to have reviews from experts. Additionally, I found it pretty useful to log:
per iteration or per minibatch in each iteration. If there are no objections from other reviewers, I would prefer adding this to the log as you have already done.
Also, comparing your log with one of my logs on main (grpo-t2s-lr-1e-6-clip-5e-3-kl-1e-3-v7-9JJUW2), there is an order of magnitude difference.
I also benchmarked the dataset via this run: grpo-t2s-lr-1e-6-clip-5e-3-kl-1e-3-v7-xRNGNv and got this, which partly explains the MFU issue. The mean prompt + generation length is likely in the range of 3,000 tokens, while max_seq is close to 14K, so the skew between max and mean can result in significant padding and wasted compute. For example, if every batch is padded to nearly 14K, we would lose nearly 80% MFU; if we can recover this, we will be back in the practical 40% MFU regime.
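A rough back-of-the-envelope check of that padding estimate (the 3,000 and 14K figures come from the benchmark run above; the arithmetic itself is only an illustration):

```python
# Hypothetical padding-overhead estimate; numbers are approximate values from the
# benchmark run mentioned above, not exact measurements.
mean_len = 3_000       # approx. mean prompt + generation length (tokens)
max_seq_len = 14_000   # approx. max sequence length in the dataset (tokens)

wasted = 1 - mean_len / max_seq_len
print(f"~{wasted:.0%} of compute goes to padding if every batch pads to max_seq_len")
# -> ~79%, consistent with the ~80% MFU loss mentioned above
```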
@bowenyang008, added the logs you mentioned here. To answer:
Not exactly, the maximum number of generated tokens is 2000 (which is what
some comments
lgtm!
Force-pushed from 78e8057 to eb71a3d
Force-pushed from eb71a3d to f325d3e
We are updating the token counter to be defined either by `action_mask` + `prompt_len` or by `len(sequences)` once the pad tokens are removed. I tested it out on compose-rl-grpo-test-5w6ZMR and both return the same output.

The bug we encounter comes from the fact that we use a StreamingDataLoader without a bespoke `get_num_tokens_in_batch` fn (like the one llmfoundry uses here). Therefore, we fall back to the default token counting function in Composer here. Since we don't have `input_ids` in the batch, we end up just using `max_seq_len * num_samples_in_batch` to count the total number of tokens, which leads to the incorrect value that we're seeing in our tests.
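For reference, a minimal sketch of what such a bespoke counter could look like. The batch keys (`sequences`, `action_mask`, `prompt_len`) and the pad-token handling are assumptions based on the description above, not the exact implementation in this PR:

```python
import torch

def get_num_tokens_in_batch(batch: dict[str, torch.Tensor], pad_token_id: int) -> int:
    """Count real (non-pad) tokens instead of max_seq_len * num_samples_in_batch.

    Assumes the batch carries padded `sequences`, an `action_mask` over
    generated tokens, and per-sample `prompt_len`; both branches should agree,
    mirroring the check described above.
    """
    if 'action_mask' in batch and 'prompt_len' in batch:
        # Generated tokens flagged by the action mask, plus the prompt lengths.
        return int(batch['action_mask'].sum().item() + batch['prompt_len'].sum().item())
    # Fallback: count every token in `sequences` that is not a pad token.
    return int((batch['sequences'] != pad_token_id).sum().item())
```

Wiring a counter like this into the dataloader spec (similar to what llmfoundry does) should keep Composer from falling back to the `max_seq_len * num_samples_in_batch` default.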