Adding Token Counter for Online RL #110

Open
wants to merge 21 commits into base: main

Conversation

@rithwik-db (Collaborator) commented Jul 15, 2025

We are updating the token counter to be computed either from the action_mask plus prompt_len, or from len(sequences) with pad tokens removed. I tested this on compose-rl-grpo-test-5w6ZMR and both approaches return the same output.

The bug comes from the fact that we use a StreamingDataLoader without a bespoke get_num_tokens_in_batch fn (like the one llmfoundry uses here), so we fall back to Composer's default token-counting function here. Since the batch has no input_ids, the default ends up counting max_seq_len * num_samples_in_batch tokens, which produces the incorrect values we're seeing in our tests.
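
For reference, a minimal sketch of the kind of per-batch counter this change implies, assuming the batch carries `action_mask`, `prompt_len`, and right-padded `sequences` tensors plus a known pad token id (the names are illustrative, not necessarily the exact ones in this PR):

```python
def count_tokens_in_batch(batch: dict, pad_token_id: int) -> int:
    """Count only the real (non-padding) tokens in an online-RL batch.

    Assumes hypothetical field names: an `action_mask` over generated tokens,
    per-sample `prompt_len`, and right-padded `sequences` (all torch tensors).
    """
    # Option 1: generated tokens (from the action mask) plus the prompt tokens.
    from_mask = int(batch['action_mask'].sum()) + int(batch['prompt_len'].sum())

    # Option 2: count every non-pad token in the padded sequences directly.
    from_sequences = int((batch['sequences'] != pad_token_id).sum())

    # Both are far smaller than max_seq_len * num_samples_in_batch when the
    # sequences are heavily padded; the test run above showed the two agree.
    return from_mask
```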

@bowenyang008 (Collaborator)

Thanks @rithwik-db, this looks good to me, but I would like reviews from the experts as well. Additionally, I found it pretty useful to log:

[RICKY] prompt_tokens: 1701
[RICKY] generated_tokens: 463

per iteration, or per minibatch within each iteration. If there is no objection from the other reviewers, I would prefer adding this to the log, as you have already done.
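
A minimal sketch of such a logging helper, assuming the same hypothetical batch fields as above:

```python
import logging

log = logging.getLogger(__name__)

def log_token_counts(batch: dict) -> None:
    # Emit per-minibatch token counts, mirroring the log lines quoted above.
    prompt_tokens = int(batch['prompt_len'].sum())
    generated_tokens = int(batch['action_mask'].sum())
    log.info(f'prompt_tokens: {prompt_tokens}')
    log.info(f'generated_tokens: {generated_tokens}')
```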

@bowenyang008 (Collaborator)

Also, comparing your log with one of my logs on main (grpo-t2s-lr-1e-6-clip-5e-3-kl-1e-3-v7-9JJUW2), there is an order-of-magnitude difference:

this branch:
Train throughput/tokens_per_sec: 12892.0938
main:
Train throughput/tokens_per_sec: 106238.8025

and the MFU on this branch is only 8%.

@bowenyang008 (Collaborator) commented Jul 16, 2025

I also benchmarked the dataset via this run: grpo-t2s-lr-1e-6-clip-5e-3-kl-1e-3-v7-xRNGNv and got the numbers below, which largely explains the MFU issue. The mean prompt + generation length is likely around 3,000 tokens, while the max sequence length is close to 14K, so the skew between max and mean can result in significant padding, i.e. useless compute. For example, if every batch is padded to nearly 14K, we would lose nearly 80% of MFU; if we can recover this, we will be back in the practical ~40% MFU regime.

2025-07-16 07:11:16,638: rank0[823][MainThread]: INFO: compose_rl.algorithms.online.callback: number of prompts in full train_prompt_loader dataset: 9440
2025-07-16 07:11:16,644: rank0[823][MainThread]: INFO: compose_rl.algorithms.online.callback: global max prompt length in train dataset: 13776
2025-07-16 07:11:16,644: rank0[823][MainThread]: INFO: compose_rl.algorithms.online.callback: global mean prompt length in train dataset: 2595.5824152542373
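
A rough back-of-the-envelope check of the padding overhead, using the numbers above (the ~3,000 mean total length is the estimate from the comment, not a measured value):

```python
mean_total_len = 3000      # rough mean prompt + generation length from the comment
max_seq_len = 13776        # global max prompt length logged above

# If every batch were padded out to roughly the global max, only this fraction
# of the compute would go to real tokens:
useful_fraction = mean_total_len / max_seq_len
print(f'useful compute fraction: {useful_fraction:.0%}')  # ~22%, i.e. ~80% padding
```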

@rithwik-db (Collaborator, Author)

@bowenyang008, I added the logs you mentioned here.

To answer:

so most of the tokens are just padding?

Not exactly. The maximum number of generated tokens is 2000 (which is what max_gen_len is set to in the yaml above), so the amount of padding per sample can't exceed that value. It's just that when we count tokens, we use Composer's default counter, which ends up computing max_seq_len * num_samples_in_batch (defined here). I added this to the PR description as well.
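
An illustrative toy comparison of the fallback count versus the actual non-pad count (made-up values, not Composer's actual code):

```python
import torch

pad_token_id = 0
# Toy padded batch: 2 samples right-padded to length 8.
sequences = torch.tensor([
    [5, 6, 7, 0, 0, 0, 0, 0],
    [5, 6, 7, 8, 9, 0, 0, 0],
])

num_samples_in_batch, max_seq_len = sequences.shape
fallback_count = max_seq_len * num_samples_in_batch     # 16: what was being reported
actual_count = int((sequences != pad_token_id).sum())   # 8: what this PR counts
print(fallback_count, actual_count)
```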

rithwik-db requested a review from bowenyang008 on July 16, 2025 at 20:06
@gupta-abhay (Collaborator) left a comment:

some comments

rithwik-db requested a review from gupta-abhay on July 17, 2025 at 18:23
@gupta-abhay (Collaborator) left a comment:

lgtm!
