
vLLM not computing correct log probs when using GRPO with temperature ≠ 1 #4159

@YonatanGideoni

Description

Reproduction

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


class DebugGRPOTrainer(GRPOTrainer):
    def _generate_and_score_completions(self, inputs):
        batch = super()._generate_and_score_completions(inputs)
        # importance_sampling_ratio = torch.exp(old_logps - vllm_logps)
        print(f"[DEBUG] first-token IS ratio: {batch['importance_sampling_ratio'][:, 0].detach().cpu()}")
        return batch


def reward_fn(completions, **kwargs):  # dummy
    return [float(len(c)) for c in completions]


args = GRPOConfig(
    output_dir="out",
    per_device_train_batch_size=2,
    num_generations=2,
    num_iterations=1,
    max_prompt_length=64,
    max_completion_length=8,
    steps_per_generation=1,
    temperature=0.6,
    use_vllm=True,
    vllm_mode="colocate",
    logging_steps=-1,
    report_to="none",
)

ds = load_dataset("trl-lib/tldr", split="train[:64]")
trainer = DebugGRPOTrainer(
    model="HuggingFaceTB/SmolLM-135M-Instruct",
    reward_funcs=[reward_fn],
    args=args,
    train_dataset=ds,
)

trainer.train()

outputs:

...
[DEBUG] first-token IS ratio: tensor([2., 2.])
...

For reference, with T=1 the printed ratio is typically in [0.85, 1.15]. The first token shows the largest deviation because perplexity drops further into a sequence. Here the ratio is almost always exactly 2 because truncated importance sampling clips it at the cap, so in practice the true deviation is much larger than what gets printed.
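
To see why the printed ratio pins at exactly 2, here is a minimal stand-alone sketch (not TRL code; the random logits and the 2.0 cap are illustrative, the cap chosen to match the printed output) of what happens when the trainer's temperature-scaled log probs are compared against vLLM log probs that were never temperature-scaled, which is the mismatch described below:

import torch

torch.manual_seed(0)
vocab, T = 32000, 0.6
logits = torch.randn(4, vocab) * 4  # a few fake next-token logit rows

# Rollout sampling uses the temperature, so sample from the scaled distribution
sampled = torch.multinomial(torch.softmax(logits / T, dim=-1), 1)

# Trainer-side log probs apply the temperature; default vLLM log probs don't
train_logps = torch.log_softmax(logits / T, dim=-1).gather(-1, sampled)
vllm_logps = torch.log_softmax(logits, dim=-1).gather(-1, sampled)

ratio = torch.exp(train_logps - vllm_logps)  # importance-sampling ratio
print(ratio.squeeze())                  # often far above 1
print(ratio.clamp(max=2.0).squeeze())   # truncation pins it at the cap, hence the printed 2s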

With a non-1 temperature, the vLLM backend by default does not apply temperature scaling to the logprobs it returns. When TIS is used to account for the rollout-inference vs. training mismatch (see https://fengyao.notion.site/off-policy-rl; this is enabled by default in TRL), this creates a large mismatch in the log probs: the trainer's are temperature-scaled while vLLM's aren't. I've seen this cause severe training instability in some runs. The fix is simple: for vLLM >= 0.10.2, add logprobs_mode="processed_logprobs" to the vLLM init in the GRPOTrainer init. I've only tested this in colocate mode, but I'd expect the same to apply when running vLLM as a separate server.
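
For concreteness, a minimal sketch of the proposed fix in colocate mode, assuming vLLM >= 0.10.2 (where exactly the kwarg gets threaded through inside GRPOTrainer's init is up to the maintainers; this just shows the argument on a bare vLLM engine):

from vllm import LLM

# Ask vLLM to return log probs after logits processing (temperature scaling),
# so they match the temperature-scaled log probs the trainer computes.
llm = LLM(
    model="HuggingFaceTB/SmolLM-135M-Instruct",
    logprobs_mode="processed_logprobs",
)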

System Info

  • Platform: Linux-6.8.0-79-generic-x86_64-with-glibc2.35
  • Python version: 3.10.18
  • TRL version: 0.23.0
  • PyTorch version: 2.8.0
  • accelerator(s): NVIDIA GeForce RTX 4090
  • Transformers version: 4.56.1
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • Datasets version: 4.0.0
  • HF Hub version: 0.34.4
  • bitsandbytes version: not installed
  • DeepSpeed version: not installed
  • Diffusers version: not installed
  • Liger-Kernel version: not installed
  • LLM-Blender version: not installed
  • OpenAI version: 1.107.2
  • PEFT version: not installed
  • vLLM version: 0.10.2

Checklist

  • I have checked that my issue isn't already filed (see open issues)
  • I have included my system information
  • Any code provided is minimal, complete, and reproducible (more on MREs)
  • Any code provided is properly formatted in code blocks (no screenshots; more on code blocks)
  • Any traceback provided is complete

Labels

🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
