Labels: ⚡ PEFT · 🏋 GRPO · 🐛 bug
Description
The Slow tests CI job is failing: https://github.com/huggingface/trl/actions/runs/18106208226/job/51521315792
ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_0_trl_internal_testing_tiny_LlamaForCausalLM_3_2 - ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_1_trl_internal_testing_tiny_MistralForCausalLM_0_2 - ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_and_peft_0_trl_internal_testing_tiny_LlamaForCausalLM_3_2 - ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_and_peft_1_trl_internal_testing_tiny_MistralForCausalLM_0_2 - ValueError: Unknown loss type: dapo
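For reference, the failure should be reproducible outside the test suite with a minimal script along these lines (a sketch, assuming the root cause is enabling the Liger GRPO loss while `GRPOConfig.loss_type` resolves to `"dapo"`; the model, dataset, and reward function below are placeholders, not the CI setup):

```python
# Hypothetical minimal reproduction (requires liger-kernel installed);
# model/dataset/reward are placeholders, not the actual slow-test fixtures.
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

def reward_len(completions, **kwargs):
    # Dummy reward: prefer completions close to 20 characters.
    return [-abs(20 - len(completion)) for completion in completions]

args = GRPOConfig(
    output_dir="grpo-liger-dapo-repro",
    use_liger_loss=True,  # routes the loss through liger_kernel's fused GRPO loss
    # loss_type is left at its default, which currently resolves to "dapo"
    # and is not recognized by liger_kernel -> ValueError during trainer.train()
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=32,
)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-LlamaForCausalLM-3.2",
    reward_funcs=reward_len,
    args=args,
    train_dataset=dataset,
)
trainer.train()  # raises ValueError: Unknown loss type: dapo
```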
Traceback:
tests/slow/test_grpo_slow.py:102: in test_training_with_liger_grpo_loss
trainer.train()
.venv/lib/python3.11/site-packages/transformers/trainer.py:2328: in train
return inner_training_loop(
.venv/lib/python3.11/site-packages/transformers/trainer.py:2672: in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/trainer.py:4009: in training_step
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/extras/profiling.py:98: in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1675: in compute_loss
return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/models/utils.py:464: in __call__
wrapper_output = wrapper_module(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/models/utils.py:457: in wrapped_forward
out = method(*_args, **_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1647: in compute_liger_loss
loss, metrics = self.liger_grpo_loss(
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/grpo_loss.py:249: in forward
return LigerFusedLinearGRPOFunction.apply(
.venv/lib/python3.11/site-packages/torch/autograd/function.py:576: in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/grpo_loss.py:142: in forward
return super().forward(
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:219: in forward
accumulate_chunk(
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:132: in accumulate_chunk
(chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:736: in compile_wrapper
return fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:111: in fused_fwd_bwd
return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
.venv/lib/python3.11/site-packages/torch/_functorch/apis.py:441: in wrapper
return eager_transforms.grad_and_value_impl(
.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py:48: in fn
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py:1365: in grad_and_value_impl
output = func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:281: in _compute_chunk_loss
chunk_loss, chunk_metrics = ppo_loss_fn(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
log_probs = GradTrackingTensor(lvl=-2, value=
tensor([[[-11.7593, -11.7727, -11.7746, ..., -11.7516, -11.7612, -11.7610],
...7493],
[-11.7659, -11.7561, -11.7638, ..., -11.7372, -11.7693, -11.7412]]],
device='cuda:0')
)
selected_token_ids = GradTrackingTensor(lvl=-2, value=
tensor([[ 90527, 55865, 33436, 109546, 75008, 54736, 104260, 111418, 62094,...7312, 36409, 10994, 84125, 21704, 11711, 72327, 66627, 216,
5997, 63032]], device='cuda:0')
)
attention_mask = GradTrackingTensor(lvl=-2, value=
tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,... 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0', dtype=torch.int32)
)
advantages = GradTrackingTensor(lvl=-2, value=
tensor([0.7245], device='cuda:0')
)
full_attention_mask = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1... 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0', dtype=torch.int32)
ref_per_token_logps = GradTrackingTensor(lvl=-2, value=
tensor([[-11.7450, -11.7677, -11.7790, -11.7808, -11.7506, -11.7674, -11.7460,
...7795, -11.7718, -11.7685, -11.7412, -11.7400, -11.7612, -11.7550,
-11.7896, -11.7596]], device='cuda:0')
)
old_per_token_logps = GradTrackingTensor(lvl=-2, value=
tensor([[-11.7450, -11.7677, -11.7790, -11.7808, -11.7506, -11.7674, -11.7460,
...7795, -11.7718, -11.7685, -11.7412, -11.7400, -11.7612, -11.7550,
-11.7896, -11.7596]], device='cuda:0')
)
ref_log_probs = None, epsilon_low = 0.2, epsilon_high = 0.2, beta = 0.0
loss_type = 'dapo', max_completion_length = 128, kwargs = {}
per_token_logps = GradTrackingTensor(lvl=-2, value=
tensor([[-11.7450, -11.7677, -11.7790, -11.7808, -11.7506, -11.7674, -11.7460,
...7795, -11.7718, -11.7685, -11.7412, -11.7400, -11.7612, -11.7550,
-11.7896, -11.7596]], device='cuda:0')
)
coef_1 = GradTrackingTensor(lvl=-2, value=
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,... 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1.]], device='cuda:0')
)
coef_2 = GradTrackingTensor(lvl=-2, value=
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,... 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1.]], device='cuda:0')
)
per_token_loss1 = GradTrackingTensor(lvl=-2, value=
tensor([[0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,...7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,
0.7245, 0.7245]], device='cuda:0')
)
per_token_loss2 = GradTrackingTensor(lvl=-2, value=
tensor([[0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,...7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,
0.7245, 0.7245]], device='cuda:0')
)
per_token_loss = GradTrackingTensor(lvl=-2, value=
tensor([[-0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245,...5,
-0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245]],
device='cuda:0')
)
@staticmethod
def ppo_loss_fn(
log_probs,
selected_token_ids,
attention_mask,
advantages,
full_attention_mask,
ref_per_token_logps=None, # shape: [chunk_size, seq_len]
old_per_token_logps=None,
ref_log_probs=None, # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
epsilon_low=0.2,
epsilon_high=0.2,
beta=0.04,
loss_type="bnpo", # ["grpo", "bnpo", "dr_grpo"]
max_completion_length=None, # Required for dr_grpo
**kwargs,
):
"""GRPO Loss Function matching GRPOTrainer implementation."""
per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-1
) # (batch_size, seq_len)
# Get reference model probabilities
if ref_per_token_logps is None:
if ref_log_probs is not None:
with torch.no_grad():
ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
-1
)
else:
ref_per_token_logps = per_token_logps.detach()
# Compute policy gradient loss with importance sampling ratio
old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if beta != 0.0:
# Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
# Combine losses
per_token_loss = per_token_loss + beta * kl_div
# Note: We normalize by the number of tokens in the batch (using full_attention_mask),
# which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
# and TRL GRPO implementation
# (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
if loss_type == "grpo":
# Average per-sequence loss
loss = (
(per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
).sum() / full_attention_mask.shape[0]
elif loss_type == "bnpo":
# Batch Normalized Per-token loss (original implementation)
loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
elif loss_type == "dr_grpo":
# Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
if max_completion_length is None:
raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
else:
> raise ValueError(f"Unknown loss type: {loss_type}")
E ValueError: Unknown loss type: dapo
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/grpo_loss.py:82: ValueError
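As the source above shows, liger_kernel's chunked GRPO loss only handles `"grpo"`, `"bnpo"`, and `"dr_grpo"`. Until it learns about `"dapo"` (or TRL maps/guards the value before handing off), a minimal sketch of an interim guard on the TRL side is shown below; the placement and names are illustrative assumptions, not the actual fix:

```python
# Illustrative guard (assumed to run where the liger GRPO loss is configured):
# fail fast with an actionable message instead of letting the fused kernel
# raise "Unknown loss type" deep inside training_step.
SUPPORTED_LIGER_LOSS_TYPES = {"grpo", "bnpo", "dr_grpo"}  # branches in liger_kernel's grpo_loss.py

if args.use_liger_loss and args.loss_type not in SUPPORTED_LIGER_LOSS_TYPES:
    raise ValueError(
        f"`use_liger_loss=True` currently supports loss_type in {sorted(SUPPORTED_LIGER_LOSS_TYPES)}, "
        f"got {args.loss_type!r}; set `loss_type` to a supported value or disable the liger loss."
    )
```

In the meantime, explicitly passing a supported value such as `loss_type="bnpo"` in `GRPOConfig` when `use_liger_loss=True` should work around the failure.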