
LigerFusedLinearGRPOLoss produces ~100x larger grad_norm than TRL's non-Liger-Kernel path due to missing vLLM IS correction and other differences #1082

@yukiu00

Description


Summary

When using `use_liger_kernel=True` with TRL v0.27.2's `GRPOTrainer` and vLLM, grad_norm is ~100x larger than in the non-Liger-Kernel path. The primary cause is that the vLLM importance-sampling (IS) correction is not applied in the Liger-Kernel loss path; several other silent differences also exist.

List of differences (TRL v0.27.2)

| # | Difference | TRL (non-Liger-Kernel) | Liger-Kernel | Impact | Silent? |
|---|---|---|---|---|---|
| 1 | vLLM IS correction | `per_token_loss *= importance_sampling_ratio` (L2351-L2352) | Not applied | ~100-300x (primary cause) | Yes |
| 2 | dapo/cispo normalizer | `num_items_in_batch / num_processes` (total tokens across the entire generation batch) (L2371) | `all_reduce(sum(attention_mask)) / world_size` (current micro-batch only), then divided by `current_gradient_accumulation_steps` (Liger-Kernel L116-L117, TRL L2194) | Several-fold, depending on length variance | Yes |
| 3 | tool_mask | `completion_mask * tool_mask` (L2243) | `completion_mask` only (TRL L2178) | Proportional to tool-token ratio | Yes |
| 4 | use_bias_correction_kl | `per_token_kl *= coef_1` (L2315-L2316) | Not implemented | KL term only | Yes |
| 5 | delta (ratio clamping) | `coef_1 = clamp(coef_1, max=delta)` (L2326-L2327) | Not implemented | Changes clipping behavior | Yes |
| 6 | off_policy_mask, top_entropy_quantile, sequence-level IS, sapo | Supported | Not supported | - | No (raises an error) |
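Difference #1 is easy to see in isolation. The sketch below is illustrative only: the function name, token values, and IS ratios are made up and do not mirror either library's internals; it just shows how skipping the `per_token_loss *= importance_sampling_ratio` step changes the reduced loss by orders of magnitude.

```python
# Sketch of difference #1: the vLLM importance-sampling (IS) correction.
# All names and numbers here are illustrative, not taken from TRL or Liger-Kernel.

def reduce_loss(per_token_loss, mask, is_ratio=None):
    """Mean per-token loss over unmasked tokens, optionally IS-corrected first."""
    if is_ratio is not None:
        # TRL's non-Liger path applies: per_token_loss *= importance_sampling_ratio
        per_token_loss = [l * r for l, r in zip(per_token_loss, is_ratio)]
    masked = [l for l, m in zip(per_token_loss, mask) if m]
    return sum(masked) / len(masked)

per_token_loss = [0.8, 1.2, 0.5, 0.9]
mask = [1, 1, 1, 0]                       # last token is padding
is_ratio = [0.005, 0.01, 0.003, 0.008]    # magnitudes in the range reported here

with_correction = reduce_loss(per_token_loss, mask, is_ratio)
without_correction = reduce_loss(per_token_loss, mask)
print(without_correction / with_correction)  # lands in the ~100-300x range
```

Because the missing factor multiplies the loss before reduction, the gradient (and hence grad_norm) is inflated by roughly the same factor.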

Where to fix

| # | Difference | Where to fix | Notes |
|---|---|---|---|
| 1 | vLLM IS correction | Liger-Kernel + TRL | Liger-Kernel: add a parameter to apply `per_token_loss *= is_ratio` before reduction. TRL: pass it from `compute_liger_loss` |
| 2 | dapo/cispo normalizer | Liger-Kernel + TRL | Liger-Kernel: accept an external normalizer. TRL: pass `num_items_in_batch` |
| 3 | tool_mask | TRL only | Pass `completion_mask * tool_mask` in `compute_liger_loss` |
| 4 | use_bias_correction_kl | Liger-Kernel + TRL | Liger-Kernel: add a flag to `ppo_loss_fn`. TRL: pass it |
| 5 | delta (ratio clamping) | Liger-Kernel + TRL | Liger-Kernel: add a clamp argument for `coef_1`. TRL: pass it |
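The normalizer mismatch in #2 can also be sketched numerically. The token counts, world size, and gradient-accumulation steps below are hypothetical; the point is that the two formulas only agree when every micro-batch has the same number of completion tokens.

```python
# Sketch of difference #2: the dapo/cispo loss normalizer.
# World size, accumulation steps, and token counts are made-up illustrative values.

world_size = 2
grad_accum_steps = 4
# Completion-token counts per rank and per micro-batch over one generation batch.
tokens = [[120, 30, 200, 50],   # rank 0
          [80, 400, 60, 100]]   # rank 1

# TRL non-Liger path: one normalizer for the whole generation batch,
# num_items_in_batch / num_processes.
total_tokens = sum(sum(rank) for rank in tokens)
trl_normalizer = total_tokens / world_size

# Liger path: per micro-batch, all_reduce(sum(attention_mask)) / world_size,
# with the loss then divided again by the accumulation steps -- equivalent to
# an effective normalizer of (reduced_tokens / world_size) * grad_accum_steps.
step = 0
reduced_tokens = sum(rank[step] for rank in tokens)
liger_normalizer = (reduced_tokens / world_size) * grad_accum_steps

print(trl_normalizer, liger_normalizer)  # differ whenever lengths vary across steps
```

With uniform lengths the two values coincide, which is why the discrepancy scales with length variance, as the table above notes.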

Observed behavior

With DAPO loss on a large model using vLLM in GRPOTrainer:

| Metric (step 1) | Liger-Kernel (`use_liger_kernel=True`) | Non-Liger-Kernel |
|---|---|---|
| grad_norm | 31.09 | 0.29 |
| loss | 0.0187 | -0.0001 |

The mean IS ratio was ~0.003–0.01, so the non-Liger-Kernel path scales the loss down by that factor, while the Liger-Kernel path passes it through unscaled.
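A back-of-the-envelope check (assuming the loss and gradient scale roughly linearly with the mean IS ratio) confirms the observed magnitude:

```python
# If the mean IS ratio is r, dropping the correction inflates the loss
# (and roughly the gradient norm) by about 1/r.
for r in (0.003, 0.01):
    print(f"mean IS ratio {r} -> ~{1 / r:.0f}x inflation")
```

For r in the reported 0.003–0.01 range, 1/r is ~100–333x, consistent with both the ~100-300x estimate in the table and the grad_norm gap (31.09 vs 0.29, ~107x) measured at step 1.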

Conclusion

The primary cause is #1: the vLLM IS correction is not applied in Liger-Kernel. Since the IS ratio must be multiplied in before reduction, this cannot be fixed on the TRL side alone; a Liger-Kernel change is necessary. Only #3 can be resolved with a one-line fix on the TRL side.
