Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 94 additions & 20 deletions src/liger_kernel/chunked_loss/fused_linear_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,26 +285,32 @@ def _compute_chunk_loss(
):
"""Compute loss for a single chunk."""
# Get policy log probabilities using chunk_forward
log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(input_chunk, weight, bias=bias, temperature=temperature)
# Pass selected_token_ids to avoid materializing full vocab logits
per_token_logps, _ = LigerFusedLinearPPOBase.chunk_forward(
input_chunk, weight, bias=bias, temperature=temperature,
selected_token_ids=selected_token_ids_chunk
)
# per_token_logps is now [B, T] instead of [B, T, V]

# Get reference log probabilities if needed
ref_log_probs = None
ref_per_token_logps_computed = None
if use_ref_model and ref_per_token_logps_chunk is None:
with torch.no_grad():
ref_log_probs, _ = LigerFusedLinearPPOBase.chunk_forward(
ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature
ref_per_token_logps_computed, _ = LigerFusedLinearPPOBase.chunk_forward(
ref_input_chunk, ref_weight, bias=ref_bias, temperature=temperature,
selected_token_ids=selected_token_ids_chunk
)

# Compute chunk loss and metrics using the provided loss function
# Note: ppo_loss_fn expects per-token logps, not full log_probs
chunk_loss, chunk_metrics = ppo_loss_fn(
log_probs=log_probs,
per_token_logps=per_token_logps,
selected_token_ids=selected_token_ids_chunk,
attention_mask=attention_mask_chunk,
advantages=advantages_chunk,
full_attention_mask=full_attention_mask,
ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else None,
ref_per_token_logps=ref_per_token_logps_chunk.float() if ref_per_token_logps_chunk is not None else ref_per_token_logps_computed,
old_per_token_logps=old_per_token_logps_chunk.float() if old_per_token_logps_chunk is not None else None,
ref_log_probs=ref_log_probs, # used when ref_per_token_logps is None
epsilon_low=epsilon_low,
epsilon_high=epsilon_high,
beta=beta,
Expand All @@ -316,19 +322,87 @@ def _compute_chunk_loss(
return chunk_loss, chunk_metrics

@staticmethod
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0):
"""Forward pass computation for a single chunk without explicit reshaping."""
# Directly compute logits via batched matrix multiplication: [B, T, H] @ [H, V] -> [B, T, V]
logits = torch.matmul(input_chunk, weight.t())
if bias is not None:
logits = logits + bias # Broadcasts bias to [B, T, V]
if temperature != 1.0:
logits = logits / temperature

# Compute log probabilities using softmax over the last dimension
log_probs = F.log_softmax(logits.float(), dim=-1)

return log_probs, logits
def chunk_forward(input_chunk, weight, bias=None, temperature=1.0, selected_token_ids=None):
"""Forward pass computation without materializing full vocab logits.

Args:
input_chunk: [B, T, H] hidden states
weight: [V, H] weight matrix
bias: [V] optional bias
temperature: float for scaling logits
selected_token_ids: [B, T] token IDs to compute logprobs for (optional)

Returns:
log_probs: [B, T, V] or [B, T] if selected_token_ids provided
logits: None (not materialized to save memory)
"""
vocab_size = weight.shape[0]
vocab_chunk_size = min(4096, vocab_size) # Process vocab in chunks

# Compute log-sum-exp incrementally across vocab chunks
max_logit = torch.full(
input_chunk.shape[:-1],
float('-inf'),
device=input_chunk.device,
dtype=torch.float32
) # [B, T]
sum_exp = torch.zeros_like(max_logit) # [B, T]

# First pass: compute log-sum-exp over all vocab
for vocab_start in range(0, vocab_size, vocab_chunk_size):
vocab_end = min(vocab_start + vocab_chunk_size, vocab_size)
weight_chunk = weight[vocab_start:vocab_end] # [chunk_V, H]

# Compute logits for this vocab chunk: [B, T, H] @ [H, chunk_V] -> [B, T, chunk_V]
logits_chunk = torch.matmul(input_chunk, weight_chunk.t())
if bias is not None:
logits_chunk = logits_chunk + bias[vocab_start:vocab_end]
if temperature != 1.0:
logits_chunk = logits_chunk / temperature

logits_chunk = logits_chunk.float()

# Update running log-sum-exp
chunk_max = logits_chunk.max(dim=-1).values # [B, T]
new_max = torch.maximum(max_logit, chunk_max)

# Adjust sum_exp for new max
sum_exp = sum_exp * torch.exp(max_logit - new_max) + \
(logits_chunk - new_max.unsqueeze(-1)).exp().sum(dim=-1)
max_logit = new_max

# log_sum_exp = max_logit + log(sum_exp)
log_sum_exp = max_logit + torch.log(sum_exp) # [B, T]

if selected_token_ids is not None:
# Only compute logits for selected tokens
# Gather weight rows for selected tokens: [B, T, H]
selected_weights = weight[selected_token_ids] # [B, T, H]

# Compute selected logits: sum over hidden dim
selected_logits = (input_chunk * selected_weights).sum(dim=-1) # [B, T]

if bias is not None:
selected_bias = bias[selected_token_ids] # [B, T]
selected_logits = selected_logits + selected_bias

if temperature != 1.0:
selected_logits = selected_logits / temperature

# Compute log_probs for selected tokens only
selected_log_probs = selected_logits.float() - log_sum_exp # [B, T]

return selected_log_probs, None # No full logits materialized
else:
# Fallback: compute full logits (for backward compatibility)
logits = torch.matmul(input_chunk, weight.t())
if bias is not None:
logits = logits + bias
if temperature != 1.0:
logits = logits / temperature

log_probs = logits.float() - log_sum_exp.unsqueeze(-1)
return log_probs, logits

@staticmethod
def backward(ctx, grad_output, *grad_metrics):
Expand Down
23 changes: 10 additions & 13 deletions src/liger_kernel/chunked_loss/grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,13 @@ def clip_coef_fn(coef, epsilon_low, epsilon_high):
class LigerFusedLinearGRPOFunction(LigerFusedLinearPPOBase):
@staticmethod
def ppo_loss_fn(
log_probs,
per_token_logps,
selected_token_ids,
attention_mask,
advantages,
full_attention_mask,
ref_per_token_logps=None, # shape: [chunk_size, seq_len]
old_per_token_logps=None,
ref_log_probs=None, # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
epsilon_low=0.2,
epsilon_high=0.2,
beta=0.04,
Expand All @@ -34,20 +33,18 @@ def ppo_loss_fn(
importance_sampling_level="token", # ["token", "sequence"] - new parameter for GSPO
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-1
) # (batch_size, seq_len)
"""GRPO Loss Function matching GRPOTrainer implementation.

Args:
per_token_logps: [B, T] log probabilities for selected tokens (already gathered)
selected_token_ids: [B, T] selected token IDs (not used anymore, kept for compatibility)
...
"""
# per_token_logps is now already gathered, no need for .gather() operation

# Get reference model probabilities
if ref_per_token_logps is None:
if ref_log_probs is not None:
with torch.no_grad():
ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-1
)
else:
ref_per_token_logps = per_token_logps.detach()
ref_per_token_logps = per_token_logps.detach()

# Compute policy gradient loss with importance sampling ratio
old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
Expand Down
166 changes: 166 additions & 0 deletions test/chunked_loss/test_grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,169 @@ def _reference_per_token_loss(
def _masked_mean(values, mask):
mask = mask.to(values.dtype)
return (values * mask).sum() / mask.sum().clamp(min=1.0)


@pytest.mark.parametrize(
"B, T, H, V",
[
(2, 4, 128, 4096),
(4, 8, 256, 8192),
],
)
@pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dapo"])
@pytest.mark.parametrize("beta", [0.0, 0.04])
@pytest.mark.parametrize("temperature", [1.0, 0.9])
def test_chunked_vs_triton_grpo_loss(B, T, H, V, loss_type, beta, temperature):
"""
Test that chunked GRPO loss (LigerFusedLinearGRPOFunction) produces
the same results as Triton GRPO loss (triton_grpo_loss) for the same inputs.

This validates that the memory-optimized chunked implementation with
vocab chunking and selected-token-only computation produces identical
results to the Triton kernel implementation.
"""
pytest.importorskip("triton")
device = infer_device()
dtype = torch.float32

# Create shared inputs
_input = torch.randn(B, T, H, device=device, dtype=dtype)
input_chunked = _input.detach().clone().requires_grad_(True)

_weight = torch.randn(V, H, device=device, dtype=dtype)
weight_chunked = _weight.detach().clone().requires_grad_(True)

bias = None # Triton loss doesn't use bias

selected_token_ids = torch.randint(0, V, (B, T), device=device)
completion_mask = torch.ones(B, T, device=device, dtype=torch.long)
completion_mask[:, -1] = 0 # mask out last token for some sequences

advantages = torch.randn(B, device=device, dtype=dtype)

old_per_token_logps = torch.randn(B, T, device=device, dtype=dtype)
ref_per_token_logps = torch.randn(B, T, device=device, dtype=dtype) if beta != 0.0 else None

use_ref_model = beta != 0.0
# Only provide ref_input/ref_weight if ref_per_token_logps is not available
# For this test, we always provide ref_per_token_logps when beta > 0
ref_input = None
ref_weight = None
ref_bias = None

# ========================================
# Run chunked loss (our fixed implementation)
# ========================================
liger_loss_fn = LigerFusedLinearGRPOLoss(
beta=beta,
epsilon_low=0.2,
epsilon_high=0.2,
temperature=temperature,
use_ref_model=use_ref_model,
loss_type=loss_type,
max_completion_length=T,
importance_sampling_level="token",
chunk_size=1,
compiled=False, # Disable torch compile for testing
)

chunked_loss, chunked_aux = liger_loss_fn.forward(
input_chunked, # _input comes FIRST
weight_chunked, # lin_weight comes SECOND
selected_token_ids,
completion_mask.float(),
advantages,
bias,
ref_per_token_logps,
old_per_token_logps,
ref_input,
ref_weight,
ref_bias,
)

# ========================================
# Prepare inputs for Triton loss
# ========================================
# Triton expects logits of shape [B, T+1, V]
# Compute logits from hidden states
with torch.no_grad():
logits_for_triton = torch.matmul(_input, _weight.t()) # [B, T, V]
# Pad with zeros to make it [B, T+1, V]
logits_for_triton = F.pad(logits_for_triton, (0, 0, 0, 1)) # pad sequence dim
logits_for_triton = logits_for_triton.contiguous()
logits_for_triton = logits_for_triton / temperature

# ========================================
# Run Triton loss
# ========================================
triton_per_token_loss, triton_per_token_kl, triton_is_clipped = triton_grpo_loss(
logits=logits_for_triton,
old_logp=old_per_token_logps,
ref_logp=ref_per_token_logps,
completion_ids=selected_token_ids,
advantages=advantages,
completion_mask=completion_mask,
temperature=1.0, # already applied to logits
beta=beta,
eps_low=0.2,
eps_high=0.2,
inplace=False,
loss_type=loss_type,
max_completion_length=T,
importance_sampling_level="token",
reduce=False,
)

# ========================================
# Compare results
# ========================================
# Extract chunked results
chunked_per_token_loss = chunked_aux[6] # per_token_loss
chunked_kl = chunked_aux[7] if beta != 0.0 else None # per_token_kl
chunked_is_clipped = chunked_aux[8] # is_clipped

# Compare per-token losses
mask = completion_mask.float()
mask_bool = mask.bool()

# Losses should match
assert_verbose_allclose(
chunked_per_token_loss * mask,
triton_per_token_loss,
atol=1e-3,
rtol=1e-2,
msg=f"Per-token losses don't match for {loss_type} with beta={beta}, temperature={temperature}"
)

# KL divergence should match if beta > 0
if beta != 0.0:
assert chunked_kl is not None
assert triton_per_token_kl is not None
assert_verbose_allclose(
chunked_kl * mask,
triton_per_token_kl,
atol=1e-3,
rtol=1e-2,
msg=f"KL divergences don't match for {loss_type} with beta={beta}"
)
else:
assert chunked_kl is None
assert triton_per_token_kl is None

# Clipping indicators should match
assert torch.equal(
chunked_is_clipped.bool()[mask_bool],
triton_is_clipped.bool()[mask_bool]
), f"Clipping indicators don't match for {loss_type} with beta={beta}"

# Compare reduced losses
chunked_reduced_loss = chunked_loss
triton_reduced_loss = _reduce_grpo_loss(triton_per_token_loss, completion_mask, loss_type, T)

assert_verbose_allclose(
chunked_reduced_loss,
triton_reduced_loss,
atol=1e-3,
rtol=1e-2,
msg=f"Reduced losses don't match for {loss_type} with beta={beta}, temperature={temperature}"
)
Loading