Skip to content

Conversation

@kashif
Copy link
Contributor

@kashif kashif commented Nov 26, 2025

Summary

Updating the forward pass to compute only the required per-token log probabilities, simplifying the loss function interface, and adding comprehensive tests to ensure correctness against the Triton implementation:

  • The chunk_forward method in fused_linear_ppo.py now computes log probabilities only for selected tokens (using selected_token_ids), avoiding allocation of large [B, T, V] tensors and instead returning [B, T] tensors for per-token log probabilities. This greatly reduces memory usage, especially for large vocabularies.
  • The loss computation in _compute_chunk_loss is updated to use these per-token log probabilities directly, and the interface for the loss function is changed accordingly (from log_probs to per_token_logps). [1] [2]

Simplification and correctness improvements:

  • The ppo_loss_fn in grpo_loss.py is simplified: it now expects pre-gathered per-token log probabilities, removing the need for an internal .gather() operation and unnecessary handling of full log probability tensors.
  • Redundant arguments and code paths for handling full vocabulary log probabilities are removed, further streamlining the code.

Testing and validation:

  • A comprehensive test, test_chunked_vs_triton_grpo_loss, is added to ensure that the chunked, memory-optimized loss matches the Triton kernel implementation across a range of configurations, including different batch sizes, sequence lengths, hidden sizes, vocab sizes, loss types, and hyperparameters. This test checks per-token losses, KL divergences, clipping indicators, and reduced losses for correctness.

Testing Done

  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant