18 changes: 10 additions & 8 deletions src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,9 @@
class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1
+    ):
        """
        Paper: https://arxiv.org/pdf/2401.08417

@@ -26,14 +28,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
        - D: Dataset of preferences

        Args:
-            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-            full_target (torch.Tensor): Non chunked full target tensor
-            beta (float): Weight for the CPO loss
+            chosen_logps_chunk (torch.Tensor): Avg log probabilities of chosen tokens in the chunk. Shape: (batch_size,).
+            rejected_logps_chunk (torch.Tensor): Avg log probabilities of rejected tokens in the chunk. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor.
Contributor:

I wonder, is it full_target or actually target_chunk?

From the fused function, we are feeding in target_chunk:
        def fused_fwd_bwd(
            input_chunk, target_chunk, ref_input_chunk, preference_labels_chunk
        ):
            """
            Fused forward and backward pass for a chunk of input and target.
            """
            if bias is not None:
                return torch.func.grad_and_value(
                    compute_loss, argnums=(0, 1, 3), has_aux=True
                )(
                    input_chunk,
                    weight,
                    target_chunk,
                    bias,
                    ref_input_chunk=ref_input_chunk,
                    preference_labels=preference_labels_chunk,
                )

Collaborator:

It feels like it should be the full target, since we use it to normalize and then sum up over all chunks, but it seems we're feeding in the target_chunk instead?

+            beta (float): Weight for the CPO loss.
"""
logits = beta * (chosen_logps - rejected_logps)
loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
return loss
logits_chunk = beta * (chosen_logps_chunk - rejected_logps_chunk)
loss_chunk = F.logsigmoid(logits_chunk).sum() / (full_target.shape[0] // 2)
return loss_chunk

    @staticmethod
    def forward(
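To illustrate the normalization question raised in the thread above: below is a minimal, self-contained sketch (a hypothetical standalone function and driver, not the repo's actual chunking loop) of why each chunk's loss is divided by the full batch size, full_target.shape[0] // 2, rather than the chunk size. Because the per-chunk losses are summed over all chunks, normalizing by the full pair count up front makes the accumulated total equal the one-shot full-batch value.

import torch
import torch.nn.functional as F

def cpo_chunk_loss(chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1):
    # Same arithmetic as preference_loss_fn above: per-chunk sum, full-batch denominator.
    logits_chunk = beta * (chosen_logps_chunk - rejected_logps_chunk)
    return F.logsigmoid(logits_chunk).sum() / (full_target.shape[0] // 2)

pairs = 8
full_target = torch.zeros(2 * pairs, 16)  # stand-in: chosen + rejected halves stacked
chosen = torch.randn(pairs)
rejected = torch.randn(pairs)

# One-shot loss over the whole batch...
full_loss = cpo_chunk_loss(chosen, rejected, full_target)

# ...equals the sum of per-chunk losses over 2-pair chunks, because every
# chunk is already normalized by the full pair count.
chunked_loss = sum(
    cpo_chunk_loss(chosen[i : i + 2], rejected[i : i + 2], full_target)
    for i in range(0, pairs, 2)
)
assert torch.allclose(full_loss, chunked_loss)

Under this reading, full_target matters only through its leading dimension (the normalizer), while the logps passed in are genuinely chunk-local, which is what the rename in this diff makes explicit.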
38 changes: 21 additions & 17 deletions src/liger_kernel/chunked_loss/dpo_loss.py
@@ -10,11 +10,11 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):

    @staticmethod
    def preference_loss_fn(
-        chosen_logps,
-        rejected_logps,
+        chosen_logps_chunk,
+        rejected_logps_chunk,
        full_target,
-        ref_chosen_logps=None,
-        ref_rejected_logps=None,
+        ref_chosen_logps_chunk=None,
+        ref_rejected_logps_chunk=None,
        beta=0.1,
    ):
        """
@@ -32,25 +32,29 @@ def preference_loss_fn(
        - E: Expected value over the dataset

        Args:
-            chosen_logps: Log probabilities of chosen tokens (batch_size,)
-            rejected_logps: Log probabilities of rejected tokens (batch_size,)
+            chosen_logps_chunk: Log probabilities of chosen tokens in the chunk (batch_size,)
+            rejected_logps_chunk: Log probabilities of rejected tokens in the chunk (batch_size,)
            full_target: Non chunked full target tensor
-            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
-            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+            ref_chosen_logps_chunk: Reference log probs of chosen tokens in the chunk (batch_size,)
+            ref_rejected_logps_chunk: Reference log probs of rejected tokens in the chunk (batch_size,)
            beta: Weight for the direct preference loss
"""

if ref_chosen_logps is None:
ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
if ref_rejected_logps is None:
ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
if ref_chosen_logps_chunk is None:
ref_chosen_logps_chunk = torch.tensor(0.0, device=chosen_logps_chunk.device)
if ref_rejected_logps_chunk is None:
ref_rejected_logps_chunk = torch.tensor(
0.0, device=rejected_logps_chunk.device
)

-        chosen_logratios = chosen_logps - ref_chosen_logps
-        rejected_logratios = rejected_logps - ref_rejected_logps
+        chosen_logratios_chunk = chosen_logps_chunk - ref_chosen_logps_chunk
+        rejected_logratios_chunk = rejected_logps_chunk - ref_rejected_logps_chunk

-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
-        return loss
+        logits_diff_chunk = beta * (chosen_logratios_chunk - rejected_logratios_chunk)
+        loss_chunk = -F.logsigmoid(logits_diff_chunk).sum() / (
+            full_target.shape[0] // 2
+        )
+        return loss_chunk

    @staticmethod
    def forward(
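For reference, here is the same chunk-wise convention applied to the DPO side: a minimal sketch (a hypothetical standalone function mirroring the renamed signature above, not the library's API) of the loss this file computes, -logsigmoid(beta * ((logps_chosen - ref_logps_chosen) - (logps_rejected - ref_logps_rejected))), where the reference terms drop out when no reference model is supplied.

import torch
import torch.nn.functional as F

def dpo_chunk_loss(
    chosen_logps_chunk,
    rejected_logps_chunk,
    full_target,
    ref_chosen_logps_chunk=None,
    ref_rejected_logps_chunk=None,
    beta=0.1,
):
    # Without a reference model, the log-ratios reduce to the policy logps.
    if ref_chosen_logps_chunk is None:
        ref_chosen_logps_chunk = torch.tensor(0.0, device=chosen_logps_chunk.device)
    if ref_rejected_logps_chunk is None:
        ref_rejected_logps_chunk = torch.tensor(0.0, device=rejected_logps_chunk.device)

    chosen_logratios = chosen_logps_chunk - ref_chosen_logps_chunk
    rejected_logratios = rejected_logps_chunk - ref_rejected_logps_chunk
    logits_diff = beta * (chosen_logratios - rejected_logratios)
    # Negative log-sigmoid, normalized by the full number of preference pairs.
    return -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)

pairs = 4
full_target = torch.zeros(2 * pairs, 16)  # stand-in target; only shape[0] is used here
loss = dpo_chunk_loss(
    torch.randn(pairs),
    torch.randn(pairs),
    full_target,
    ref_chosen_logps_chunk=torch.randn(pairs),
    ref_rejected_logps_chunk=torch.randn(pairs),
)
print(float(loss))  # scalar; smaller when chosen beats rejected relative to the reference

The full-batch denominator applies here as well, so DPO chunk losses accumulate to the full-batch mean exactly as in the CPO sketch above.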