
CI fails: ValueError: Unknown loss type: dapo #4172

@albertvillanova

Description

CI is failing for the Slow tests workflow: https://github.com/huggingface/trl/actions/runs/18106208226/job/51521315792

ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_0_trl_internal_testing_tiny_LlamaForCausalLM_3_2 - ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_1_trl_internal_testing_tiny_MistralForCausalLM_0_2 - ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_and_peft_0_trl_internal_testing_tiny_LlamaForCausalLM_3_2 - ValueError: Unknown loss type: dapo
FAILED tests/slow/test_grpo_slow.py::GRPOTrainerSlowTester::test_training_with_liger_grpo_loss_and_peft_1_trl_internal_testing_tiny_MistralForCausalLM_0_2 - ValueError: Unknown loss type: dapo
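
All four failures go through the Liger GRPO loss path. A minimal sketch of what the failing tests exercise (not the actual test code; it assumes the GRPOConfig fields use_liger_loss and loss_type and a callable reward function, as in current TRL, and the tiny model id inferred from the test names):

# Hypothetical minimal reproduction, not the slow test itself.
# Assumes GRPOConfig exposes `use_liger_loss` and `loss_type` as in current TRL.
from datasets import Dataset
from trl import GRPOConfig, GRPOTrainer

dataset = Dataset.from_dict({"prompt": ["Hello", "World"] * 8})

def constant_reward(completions, **kwargs):
    # The reward value is irrelevant; we only need training to reach compute_loss.
    return [1.0 for _ in completions]

args = GRPOConfig(
    output_dir="/tmp/grpo-dapo-repro",
    use_liger_loss=True,   # routes compute_loss through liger_kernel's GRPO loss
    loss_type="dapo",      # the value shown in the traceback locals below
    per_device_train_batch_size=4,
    num_generations=4,
    max_completion_length=8,
    report_to="none",
)

trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-LlamaForCausalLM-3.2",
    reward_funcs=constant_reward,
    args=args,
    train_dataset=dataset,
)
trainer.train()  # expected to raise ValueError: Unknown loss type: dapo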

Traceback:

tests/slow/test_grpo_slow.py:102: in test_training_with_liger_grpo_loss
    trainer.train()
.venv/lib/python3.11/site-packages/transformers/trainer.py:2328: in train
    return inner_training_loop(
.venv/lib/python3.11/site-packages/transformers/trainer.py:2672: in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/transformers/trainer.py:4009: in training_step
    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/extras/profiling.py:98: in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1675: in compute_loss
    return self._forward_redirection(model, unwrapped_model, self.compute_liger_loss, unwrapped_model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/models/utils.py:464: in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
trl/models/utils.py:457: in wrapped_forward
    out = method(*_args, **_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
trl/trainer/grpo_trainer.py:1647: in compute_liger_loss
    loss, metrics = self.liger_grpo_loss(
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784: in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/grpo_loss.py:249: in forward
    return LigerFusedLinearGRPOFunction.apply(
.venv/lib/python3.11/site-packages/torch/autograd/function.py:576: in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/grpo_loss.py:142: in forward
    return super().forward(
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:219: in forward
    accumulate_chunk(
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:132: in accumulate_chunk
    (chunk_grad_input, chunk_grad_weight, *chunk_grad_bias), (chunk_loss, chunk_metrics) = fused_fwd_bwd(
.venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py:736: in compile_wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:111: in fused_fwd_bwd
    return torch.func.grad_and_value(compute_loss, argnums=argnums, has_aux=True)(
.venv/lib/python3.11/site-packages/torch/_functorch/apis.py:441: in wrapper
    return eager_transforms.grad_and_value_impl(
.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py:48: in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/torch/_functorch/eager_transforms.py:1365: in grad_and_value_impl
    output = func(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^
.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/fused_linear_ppo.py:281: in _compute_chunk_loss
    chunk_loss, chunk_metrics = ppo_loss_fn(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

log_probs = GradTrackingTensor(lvl=-2, value=
    tensor([[[-11.7593, -11.7727, -11.7746,  ..., -11.7516, -11.7612, -11.7610],
   ...7493],
             [-11.7659, -11.7561, -11.7638,  ..., -11.7372, -11.7693, -11.7412]]],
           device='cuda:0')
)
selected_token_ids = GradTrackingTensor(lvl=-2, value=
    tensor([[ 90527,  55865,  33436, 109546,  75008,  54736, 104260, 111418,  62094,...7312,  36409,  10994,  84125,  21704,  11711,  72327,  66627,    216,
               5997,  63032]], device='cuda:0')
)
attention_mask = GradTrackingTensor(lvl=-2, value=
    tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,... 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
             1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0', dtype=torch.int32)
)
advantages = GradTrackingTensor(lvl=-2, value=
    tensor([0.7245], device='cuda:0')
)
full_attention_mask = tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1, 1... 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0', dtype=torch.int32)
ref_per_token_logps = GradTrackingTensor(lvl=-2, value=
    tensor([[-11.7450, -11.7677, -11.7790, -11.7808, -11.7506, -11.7674, -11.7460,
 ...7795, -11.7718, -11.7685, -11.7412, -11.7400, -11.7612, -11.7550,
             -11.7896, -11.7596]], device='cuda:0')
)
old_per_token_logps = GradTrackingTensor(lvl=-2, value=
    tensor([[-11.7450, -11.7677, -11.7790, -11.7808, -11.7506, -11.7674, -11.7460,
 ...7795, -11.7718, -11.7685, -11.7412, -11.7400, -11.7612, -11.7550,
             -11.7896, -11.7596]], device='cuda:0')
)
ref_log_probs = None, epsilon_low = 0.2, epsilon_high = 0.2, beta = 0.0
loss_type = 'dapo', max_completion_length = 128, kwargs = {}
per_token_logps = GradTrackingTensor(lvl=-2, value=
    tensor([[-11.7450, -11.7677, -11.7790, -11.7808, -11.7506, -11.7674, -11.7460,
 ...7795, -11.7718, -11.7685, -11.7412, -11.7400, -11.7612, -11.7550,
             -11.7896, -11.7596]], device='cuda:0')
)
coef_1 = GradTrackingTensor(lvl=-2, value=
    tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,...      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
             1., 1.]], device='cuda:0')
)
coef_2 = GradTrackingTensor(lvl=-2, value=
    tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,...      1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
             1., 1.]], device='cuda:0')
)
per_token_loss1 = GradTrackingTensor(lvl=-2, value=
    tensor([[0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,...7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,
             0.7245, 0.7245]], device='cuda:0')
)
per_token_loss2 = GradTrackingTensor(lvl=-2, value=
    tensor([[0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,...7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245, 0.7245,
             0.7245, 0.7245]], device='cuda:0')
)
per_token_loss = GradTrackingTensor(lvl=-2, value=
    tensor([[-0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245,...5,
             -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245, -0.7245]],
           device='cuda:0')
)

    @staticmethod
    def ppo_loss_fn(
        log_probs,
        selected_token_ids,
        attention_mask,
        advantages,
        full_attention_mask,
        ref_per_token_logps=None,  # shape: [chunk_size, seq_len]
        old_per_token_logps=None,
        ref_log_probs=None,  # used when ref_per_token_logps is None (shape: [chunk_size, seq_len, vocab_size])
        epsilon_low=0.2,
        epsilon_high=0.2,
        beta=0.04,
        loss_type="bnpo",  # ["grpo", "bnpo", "dr_grpo"]
        max_completion_length=None,  # Required for dr_grpo
        **kwargs,
    ):
        """GRPO Loss Function matching GRPOTrainer implementation."""
        per_token_logps = log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
            -1
        )  # (batch_size, seq_len)
    
        # Get reference model probabilities
        if ref_per_token_logps is None:
            if ref_log_probs is not None:
                with torch.no_grad():
                    ref_per_token_logps = ref_log_probs.gather(dim=-1, index=selected_token_ids.unsqueeze(-1)).squeeze(
                        -1
                    )
            else:
                ref_per_token_logps = per_token_logps.detach()
    
        # Compute policy gradient loss with importance sampling ratio
        old_per_token_logps = old_per_token_logps if old_per_token_logps is not None else per_token_logps.detach()
        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
        coef_2 = clip_coef_fn(coef_1, epsilon_low, epsilon_high)
        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
        if beta != 0.0:
            # Compute KL penalty (approximates KL[per_token_logps, ref_per_token_logps])
            kl_div = k3_loss_fn(ref_per_token_logps, per_token_logps)
            # Combine losses
            per_token_loss = per_token_loss + beta * kl_div
    
        # Note: We normalize by the number of tokens in the batch (using full_attention_mask),
        # which is consistent with the DAPO loss implementation (https://arxiv.org/html/2503.14476v1)
        # and TRL GRPO implementation
        # (https://github.com/huggingface/trl/blob/e751a16df56e70190fb94bed4a2035eec3303777/trl/trainer/grpo_trainer.py#L966)
        if loss_type == "grpo":
            # Average per-sequence loss
            loss = (
                (per_token_loss * attention_mask).sum(-1) / torch.clamp(attention_mask.sum(-1), min=1.0)
            ).sum() / full_attention_mask.shape[0]
        elif loss_type == "bnpo":
            # Batch Normalized Per-token loss (original implementation)
            loss = (per_token_loss * attention_mask).sum() / torch.clamp(full_attention_mask.sum(), min=1.0)
        elif loss_type == "dr_grpo":
            # Dimension-Reduced GRPO (normalize by batch_size * max_completion_length)
            if max_completion_length is None:
                raise ValueError("max_completion_length must be provided for loss_type 'dr_grpo'")
            loss = (per_token_loss * attention_mask).sum() / (full_attention_mask.shape[0] * max_completion_length)
        else:
>           raise ValueError(f"Unknown loss type: {loss_type}")
E           ValueError: Unknown loss type: dapo

.venv/lib/python3.11/site-packages/liger_kernel/chunked_loss/grpo_loss.py:82: ValueError
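
From the signature above, the installed liger_kernel ppo_loss_fn only knows the loss types "grpo", "bnpo" and "dr_grpo", while the trainer forwards loss_type='dapo', so control falls through to the ValueError. Until TRL and liger_kernel agree on the supported set, one possible guard (a sketch, not a confirmed fix; it assumes the same GRPOConfig fields as above) is to fail fast before the trainer is built, or pin loss_type to a liger-supported value:

# Sketch of a defensive guard, not a confirmed fix. Assumes GRPOConfig exposes
# `use_liger_loss` and `loss_type`; the supported set is read off the
# liger_kernel `ppo_loss_fn` source shown above.
from trl import GRPOConfig

LIGER_SUPPORTED_LOSS_TYPES = {"grpo", "bnpo", "dr_grpo"}

def liger_safe_grpo_config(**kwargs) -> GRPOConfig:
    """Build a GRPOConfig, rejecting loss types the installed liger GRPO kernel cannot handle."""
    args = GRPOConfig(**kwargs)
    if getattr(args, "use_liger_loss", False) and args.loss_type not in LIGER_SUPPORTED_LOSS_TYPES:
        raise ValueError(
            f"loss_type={args.loss_type!r} is not supported by the installed liger_kernel; "
            f"use one of {sorted(LIGER_SUPPORTED_LOSS_TYPES)} or set use_liger_loss=False."
        )
    return args

# Example: this raises immediately instead of failing deep inside training.
# args = liger_safe_grpo_config(output_dir="/tmp/out", use_liger_loss=True, loss_type="dapo")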

Labels

⚡ PEFT (Related to PEFT), 🏋 GRPO (Related to GRPO), 🐛 bug (Something isn't working)
