v0.16.0
Major and breaking
🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication
Previously, vLLM could only be used by dedicating a single GPU to generation, which both capped its scalability benefits and prevented multi-node training. This limitation has now been removed!
GRPO can now scale efficiently to models exceeding 70B parameters and supports fast multi-node training.
To take advantage of this, simply launch a vLLM server using the following command:
trl vllm-serve --model <model_name> --tensor_parallel_size <tp_size>
Then, start GRPO training with use_vllm=True.
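For example, mirroring the other snippets in these notes:
training_args = GRPOConfig(..., use_vllm=True)
The vLLM server occupies its own GPUs, so launch the training job on the remaining devices.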
Below is a comparison of GRPO throughput with and without vLLM, across different TP values and model sizes.
by @binary-husky and @qgallouedec in #3094
🐦🔥 6x faster GRPO with multi-step optimization
This release introduces the multi-step trick, which reuses generated completions across multiple optimization steps, significantly speeding up training.
To keep these off-policy updates stable, we've implemented importance sampling and clipping logic.

To use it, simply set num_iterations to a value greater than 1.
training_args = GRPOConfig(..., num_iterations=4)
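For intuition, here is a minimal sketch (not TRL's exact implementation) of the importance-sampled, clipped per-token loss that makes this data reuse safe; per_token_logps come from the current policy and old_per_token_logps from the policy that generated the completions:

import torch

def clipped_per_token_loss(per_token_logps, old_per_token_logps, advantages, epsilon=0.2):
    # Importance-sampling ratio between the current policy and the policy that generated the data
    ratio = torch.exp(per_token_logps - old_per_token_logps)
    # PPO-style clipped surrogate; broadcast the per-sequence advantage over tokens
    unclipped = ratio * advantages.unsqueeze(1)
    clipped = torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages.unsqueeze(1)
    # Take the pessimistic (minimum) term; the minus sign turns it into a loss
    return -torch.min(unclipped, clipped)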
by @qgallouedec in #2899
🌍 Use global normalization in GRPO
As demonstrated in Dr GRPO, sequence-level normalization can introduce a response-level length bias.
To address this, we now normalize the loss by the total number of tokens in the batch, ensuring more consistent and unbiased training.
- loss = ((per_token_loss * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
+ loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
by @edbeeching in #2881
⚖️ Add option not to scale rewards
As demonstrated in Dr GRPO, scaling rewards can introduce a question-level difficulty bias. To address this, we have now added an option to disable reward scaling in GRPO.
training_args = GRPOConfig(..., scale_rewards=False)
advantages = rewards - mean_grouped_rewards
- advantages = advantages / std_grouped_rewards
+ if self.args.scale_rewards:
+ advantages = advantages / std_grouped_rewards
It's likely that we'll make this (scale_rewards=False) the default behavior in the future.
by @qgallouedec in #3135
🤸♀️ Domain-specific rewards in GRPO
When optimizing across multiple domains, not all reward functions are relevant for every sample. For example, a math verifier's reward does not apply to grammar samples, and a grammar verifier's reward does not apply to math samples.
It is now possible to return None for rewards that do not make sense for a given sample. For instance, when the domain is specified in a column like domain, you can implement it as follows:
def math_reward(completions, domain, **kwargs):
    rewards = []
    for completion, dom in zip(completions, domain):
        if dom == "math":
            rewards.append(verify(completion))
        else:
            rewards.append(None)
    return rewards
This allows for more domain-specific reward handling, ensuring that irrelevant rewards are ignored and don’t interfere with optimization.
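As a usage sketch, the math_reward above can then be combined with other domain-specific reward functions; extra dataset columns such as domain are forwarded to each reward function as keyword arguments. The dataset name and check_grammar verifier below are hypothetical placeholders:

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def check_grammar(text):
    # Stand-in verifier: replace with a real grammar checker
    return 1.0

def grammar_reward(completions, domain, **kwargs):
    # Counterpart of math_reward: score only grammar samples, return None for the rest
    return [check_grammar(c) if d == "grammar" else None for c, d in zip(completions, domain)]

dataset = load_dataset("my-org/multi-domain-prompts", split="train")  # hypothetical dataset with "prompt" and "domain" columns

trainer = GRPOTrainer(
    model="Qwen/Qwen2.5-0.5B-Instruct",          # placeholder model name
    reward_funcs=[math_reward, grammar_reward],  # None rewards are ignored for that sample
    args=GRPOConfig(output_dir="grpo-multi-domain"),
    train_dataset=dataset,
)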
by @shirinyamani in #3079
🍃 Do not load reference model when beta == 0.0
It has been observed that not minimizing the KL divergence between the trained model and the reference model can still yield good results, while significantly reducing memory usage and compute. This is because there is no need to store the reference model in memory or perform a forward pass for it.
When beta is set to 0.0, the reference model is not loaded and the KL divergence is not computed, leading to savings in both time and memory.
training_args = GRPOConfig(..., beta=0.0)
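by @ingambe in #2806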
🕊️ Padding-free for SFT
Padding-free batching is an alternative approach to packing for reducing memory usage. In this method, a batch is first sampled and then flattened into a single sequence, avoiding padding. Unlike packing, which can result in incomplete sequences by combining parts of different samples, padding-free batching ensures that all sequences remain complete and intact.
To enable padding-free batching in SFT, simply set padding_free=True in the SFTConfig, and make sure to use flash_attention_2 as the attention implementation.
training_args = SFTConfig(..., padding_free=True, model_init_kwargs={"attn_implementation": "flash_attention_2"})
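To illustrate the idea with a toy example (not TRL's internal code): instead of padding every sample to the longest one in the batch, the samples are concatenated into a single sequence and the position ids restart at each boundary, which lets FlashAttention-2's variable-length kernels keep attention confined to each original sample:

# Three samples of different lengths, with made-up token ids
seqs = [[5, 9, 2], [7, 4], [8, 1, 3, 6]]

# Flatten into one sequence instead of building a 3 x 4 padded tensor
input_ids = [tok for seq in seqs for tok in seq]             # [5, 9, 2, 7, 4, 8, 1, 3, 6]
position_ids = [i for seq in seqs for i in range(len(seq))]  # [0, 1, 2, 0, 1, 0, 1, 2, 3]

No padding tokens are stored, and each sample stays complete, unlike packing.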
by @qgallouedec in #3076
🎬 Clip Higher for Better Exploration
As outlined in the DAPO paper, increasing the upper clipping bound leads to higher entropy during generation, promoting better exploration. To enable this, we've added support for adjusting the upper bound (epsilon_high) directly in the default GRPO trainer.
training_args = GRPOConfig(epsilon_high=0.28)
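Concretely (an illustrative sketch, not TRL's exact code), the clipping of the importance-sampling ratio becomes asymmetric, giving probability-increasing updates more headroom than a single symmetric epsilon would:

import torch

def clip_ratio(ratio, epsilon_low=0.2, epsilon_high=0.28):
    # DAPO-style "clip higher": the upper bound 1 + epsilon_high exceeds the usual 1 + epsilon_low
    return torch.clamp(ratio, 1 - epsilon_low, 1 + epsilon_high)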
by @shirinyamani in #3118
Bug fixes
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🪂 Don't gather logits in SFT to avoid hanging by @qgallouedec in #2890
- ♻️ Fix caching in SFT by @qgallouedec in #2945
- 🐯 Fix LigerKernel for SFTTrainer by @lewtun @kashif and @qgallouedec in #2874, #2940 and #2949
- 🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility by @kiddj in #2871
- 🛣️ inference_mode to no_grad when computing old_per_token_logps by @qgallouedec in #2987
- 🏊 [SFT] Compatibility with padding free and iterable dataset by @qgallouedec in #3053
- Fixing JSD loss computation in GKDTrainer as per definition by @abhigoyal1997 in #3043
Minor
- 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT by @kashif in #2862
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW and @DanFosing in #2863 and #2939
- ✨ Add vLLM guided decoding support to GRPO Trainer by @kldzj in #2811
- 🩳 max_seq_length to max_length by @qgallouedec in #2895 and #2947
- Optimize vllm num_generations by @edbeeching in #2855
- 📍 [GRPO] add gradient_checkpointing by @kashif in #2848
- 🪪 Adds profiling decorators for GRPOTrainer by @edbeeching in #2889 and #2975
- 🐈 Bye bye chat by @qgallouedec in #2934
- 📇 GRPO: print completions to console and update docs by @nopepper in #2951
- 👧🏽 Adding DoRA support to model config by @nbasyl in #2974
- 🧗 Add GRPO Trainer support for third-party accelerators by @ji-huazhong in #2836
- 🪙 [SFT] Log num_tokens and some logging fixes by @qgallouedec in #3006
- 🌡️ Fix temperature inconsistency in GRPO trainer by @Aladoro in #3029
- ⛔ Add EOS token to processed input in SFT by @qgallouedec in #3091
- ⚡ Pack 300 times faster, truncate 100 times faster by @mariosasko in #3009
What's Changed
- [SFT] fix check for AutoLigerKernelForCausalLM by @kashif in #2874
- 🆙 Bump vLLM min version to 0.7.2 by @edbeeching in #2860
- [GRPO] Fix loss normalization by @edbeeching in #2881
- 💬 Add maybe_convert_to_chatml map for conversational datasets in SFT by @kashif in #2862
- 🧶 [GRPO][vLLM + LoRA] Move unmerge of PEFT model after weight loading by @XZ-X in #2873
- 🍟 [SFT] Handles the dataset if it has been preprocessed by @BenasdTW in #2863
- Optimize vllm num_generations by @edbeeching in #2855
- 🪂 Don't gather logits in SFT to avoid hanging by @qgallouedec in #2890
- ✨ Add vLLM guided decoding support to GRPO Trainer by @kldzj in #2811
- ⚰️ Remove deprecated by @qgallouedec in #2894
- 🩳 max_seq_length to max_length by @qgallouedec in #2895
- 🍃 GRPO - Do not load reference model when beta == 0 by @ingambe in #2806
- 📍 [GRPO] add gradient_checkpointing by @kashif in #2848
- 🪪 Adds profiling decorators for GRPOTrainer by @edbeeching in #2889
- 🐦🔥 6x faster GRPO with multi-step optimization by @qgallouedec in #2899
- 🔹 Fix: Miscalculated mask shape in comments by @linkedlist771 in #2925
- 🤖 Style bot by @qgallouedec in #2935
- 🧼 Upgrade ruff by @qgallouedec in #2938
- 🐈 Bye bye chat by @qgallouedec in #2934
- ♻️ Fix caching in SFT by @qgallouedec in #2945
- 📋 Add vLLM version to environment printout by @qgallouedec in #2946
- ☠️ Update max_seq_length to max_length in SFTConfig by @qgallouedec in #2947
- 🐯 Fix LigerKernel for SFTTrainer by @lewtun in #2940
- ✋ Prevent applying the chat template to tokenized datasets by @DanFosing in #2939
- 📇 GRPO: print completions to console and update docs by @nopepper in #2951
- ↩️ Fix typo in TextEnvironment init param, should be max_tool_response by @shenxiangzhuang in #2921
- 🗿 Updated DPO default values for alpha and tau by @Ishan-Kumar2 in #2918
- 📌 Pin liger-kernel and vLLM by @qgallouedec in #2952
- ⏪ Parameterize enable_prefix_caching by @ji-huazhong in #2900
- 🔢 Fix GRPO doc about num_iterations by @qgallouedec in #2966
- Update grpo_trainer.py by @tpoisonooo in #2973
- 👧🏽 Adding DoRA support to model config by @nbasyl in #2974
- 🧗 Add GRPO Trainer support for third-party accelerators by @ji-huazhong in #2836
- 🕸 Add distributing training guide by @qgallouedec in #2956
- 👂 Update learning rate doc in KTOConfig by @sileod in #2912
- 🌌 Fix logits computation in trainer prediction step by @logicaltrojan in #2969
- 🪪 Adds a more fine-grained profiling context by @edbeeching in #2975
- 🧬 Fix typo in grpo_trainer.py by @congchan in #2988
- 📜 Update README and doc index by @qgallouedec in #2986
- 📑 Fix logged metrics for KTO by @vaibhavjindal in #2982
- ⚰️ Deprecate liger-kernel by @qgallouedec in #2949
- 🔍 Update GRPO config documentation for beta parameter stability by @nopepper in #2992
- 🫔 [GRPO] Pass wrapped model to unwrap_model_for_generation for DeepSpeed Stage-3 compatibility by @kiddj in #2871
- 🛣️ inference_mode to no_grad when computing old_per_token_logps by @qgallouedec in #2987
- 🚀 DeepSpeed integration documentation by @qgallouedec in #2993
- Update pr_style_bot.yml by @qgallouedec in #3003
- 🪙 [SFT] Log num_tokens and some logging fixes by @qgallouedec in #3006
- Improve ci by @paulinebm in #3007
- ✌️Remove double compute of sum in SFTTrainer by @lexasub in #3001
- 📚 Update customization and distributing training documentation by @qgallouedec in #2991
- 🌍 Use global normalization for KL logging (to match normalization for loss) by @tchang1997 in #3004
- 🗜️ Loosened tokenizer type hint on apply_chat_template by @jamesbraza in #3005
- 🎲 Add support for additional generation kwargs in GRPO Trainer by @nopepper in #2989
- 🚀 Supporting deepspeed>=0.16.4's rename by @jamesbraza in #2963
- 🌡️ Fix temperature inconsistency in GRPO trainer by @Aladoro in #3029
- 🏁 Passing custom BOS/EOS token to GRPOTrainer.generation_config by @jamesbraza in #3046
- 💠 Fixing SFTTrainer.compute_loss crash with accelerate by @jamesbraza in #3048
- 👯 [GRPO] Relax the assumption that prompts are unique within a batch by @qgallouedec in #3052
- [GRPO] use argument names with processing_class by @kashif in #3062
- 🦥 Fixed SFTTrainer.compute_loss hang from #3048's PR comments by @jamesbraza in #3056
- 🏊 [SFT] Compatibility with padding free and iterable dataset by @qgallouedec in #3053
- Fixing JSD loss computation in GKDTrainer as per definition by @abhigoyal1997 in #3043
- 🎭 Minor spelling fix in documentation (caracteres -> characters) by @esnible in #3074
- 💎 Gemma 3 SFT example on Codeforces dataset by @qgallouedec in #3070
- 🫣 [GRPO] add cache_implementation option in GRPO by @kashif in #3075
- ⛔ Add EOS token to processed input in SFT by @qgallouedec in #3091
- 🕊️ Padding-free for SFT by @qgallouedec in #3076
- add "_prepare_fsdp" for DPOTrainer by @faaany in #2539
- Use main process for dataset.map by @lewtun in #3106
- Flexible_reward by @shirinyamani in #3079
- 🎬 Clip higher by @shirinyamani in #3118
- 🚀 Scaling GRPO to 70B+ Models and Multi-Node Training with vLLM Server & NCCL Communication by @binary-husky in #3094
- ⚡ Pack 300 times faster, truncate 100 times faster by @mariosasko in #3009
- ☎️ Documentation for disable gathering of model weights for generation in DeepSpeed ZeRO-3 by @qgallouedec in #3136
- ⚖️ Add option not to scale rewards (Dr. GRPO) by @qgallouedec in #3135
- Release: v0.16 by @qgallouedec in #3137
New Contributors
- @XZ-X made their first contribution in #2873
- @BenasdTW made their first contribution in #2863
- @kldzj made their first contribution in #2811
- @ingambe made their first contribution in #2806
- @linkedlist771 made their first contribution in #2925
- @DanFosing made their first contribution in #2939
- @nopepper made their first contribution in #2951
- @shenxiangzhuang made their first contribution in #2921
- @Ishan-Kumar2 made their first contribution in #2918
- @tpoisonooo made their first contribution in #2973
- @nbasyl made their first contribution in #2974
- @sileod made their first contribution in #2912
- @logicaltrojan made their first contribution in #2969
- @congchan made their first contribution in #2988
- @vaibhavjindal made their first contribution in #2982
- @kiddj made their first contribution in #2871
- @paulinebm made their first contribution in #3007
- @lexasub made their first contribution in #3001
- @tchang1997 made their first contribution in #3004
- @Aladoro made their first contribution in #3029
- @abhigoyal1997 made their first contribution in #3043
- @esnible made their first contribution in #3074
- @mariosasko made their first contribution in #3009
Full Changelog: v0.15.0...v0.16.0