diff --git a/src/liger_kernel/chunked_loss/cpo_loss.py b/src/liger_kernel/chunked_loss/cpo_loss.py
index 2b8052e25..b0ecd7c33 100644
--- a/src/liger_kernel/chunked_loss/cpo_loss.py
+++ b/src/liger_kernel/chunked_loss/cpo_loss.py
@@ -9,7 +9,9 @@ class LigerFusedLinearCPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
-    def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
+    def preference_loss_fn(
+        chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1
+    ):
         """
         Paper: https://arxiv.org/pdf/2401.08417
 
@@ -26,14 +28,14 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1):
         - D: Dataset of preferences
 
         Args:
-            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
-            full_target (torch.Tensor): Non chunked full target tensor
-            beta (float): Weight for the CPO loss
+            chosen_logps_chunk (torch.Tensor): Avg log probabilities of chosen tokens in the chunk. Shape: (batch_size,).
+            rejected_logps_chunk (torch.Tensor): Avg log probabilities of rejected tokens in the chunk. Shape: (batch_size,).
+            full_target (torch.Tensor): Non chunked full target tensor.
+            beta (float): Weight for the CPO loss.
         """
-        logits = beta * (chosen_logps - rejected_logps)
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
-        return loss
+        logits_chunk = beta * (chosen_logps_chunk - rejected_logps_chunk)
+        loss_chunk = F.logsigmoid(logits_chunk).sum() / (full_target.shape[0] // 2)
+        return loss_chunk
 
     @staticmethod
     def forward(
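Note on the normalization kept above: each chunk's log-sigmoid term is divided by the number of preference pairs in the full batch (full_target.shape[0] // 2), not in the chunk, so the caller can simply sum the per-chunk results. A minimal standalone sketch that checks this property; the names and driver values here are illustrative only, not the fused kernel:

import torch
import torch.nn.functional as F

def cpo_chunk_term(chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1):
    # Mirrors the chunk-level term above: normalize by the pair count of the FULL batch.
    logits_chunk = beta * (chosen_logps_chunk - rejected_logps_chunk)
    return F.logsigmoid(logits_chunk).sum() / (full_target.shape[0] // 2)

torch.manual_seed(0)
num_pairs = 8
chosen = torch.randn(num_pairs)
rejected = torch.randn(num_pairs)
full_target = torch.zeros(2 * num_pairs, 16)  # chosen rows stacked on rejected rows

full = cpo_chunk_term(chosen, rejected, full_target)
chunked = sum(
    cpo_chunk_term(c, r, full_target)
    for c, r in zip(chosen.chunk(4), rejected.chunk(4))
)
assert torch.allclose(full, chunked)  # per-chunk terms sum to the full-batch value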
diff --git a/src/liger_kernel/chunked_loss/dpo_loss.py b/src/liger_kernel/chunked_loss/dpo_loss.py
index cf07e186e..8eee230fe 100644
--- a/src/liger_kernel/chunked_loss/dpo_loss.py
+++ b/src/liger_kernel/chunked_loss/dpo_loss.py
@@ -10,11 +10,11 @@ class LigerFusedLinearDPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
     def preference_loss_fn(
-        chosen_logps,
-        rejected_logps,
+        chosen_logps_chunk,
+        rejected_logps_chunk,
         full_target,
-        ref_chosen_logps=None,
-        ref_rejected_logps=None,
+        ref_chosen_logps_chunk=None,
+        ref_rejected_logps_chunk=None,
         beta=0.1,
     ):
         """
@@ -32,25 +32,29 @@ def preference_loss_fn(
         - E: Expected value over the dataset
 
         Args:
-            chosen_logps: Log probabilities of chosen tokens (batch_size,)
-            rejected_logps: Log probabilities of rejected tokens (batch_size,)
+            chosen_logps_chunk: Log probabilities of chosen tokens in the chunk (batch_size,)
+            rejected_logps_chunk: Log probabilities of rejected tokens in the chunk (batch_size,)
             full_target: Non chunked full target tensor
-            ref_chosen_logps: Reference log probs of chosen tokens (batch_size,)
-            ref_rejected_logps: Reference log probs of rejected tokens (batch_size,)
+            ref_chosen_logps_chunk: Reference log probs of chosen tokens in the chunk (batch_size,)
+            ref_rejected_logps_chunk: Reference log probs of rejected tokens in the chunk (batch_size,)
             beta: Weight for the direct preference loss
         """
-        if ref_chosen_logps is None:
-            ref_chosen_logps = torch.tensor(0.0, device=chosen_logps.device)
-        if ref_rejected_logps is None:
-            ref_rejected_logps = torch.tensor(0.0, device=rejected_logps.device)
+        if ref_chosen_logps_chunk is None:
+            ref_chosen_logps_chunk = torch.tensor(0.0, device=chosen_logps_chunk.device)
+        if ref_rejected_logps_chunk is None:
+            ref_rejected_logps_chunk = torch.tensor(
+                0.0, device=rejected_logps_chunk.device
+            )
 
-        chosen_logratios = chosen_logps - ref_chosen_logps
-        rejected_logratios = rejected_logps - ref_rejected_logps
+        chosen_logratios_chunk = chosen_logps_chunk - ref_chosen_logps_chunk
+        rejected_logratios_chunk = rejected_logps_chunk - ref_rejected_logps_chunk
 
-        logits_diff = beta * (chosen_logratios - rejected_logratios)
-        loss = -F.logsigmoid(logits_diff).sum() / (full_target.shape[0] // 2)
-        return loss
+        logits_diff_chunk = beta * (chosen_logratios_chunk - rejected_logratios_chunk)
+        loss_chunk = -F.logsigmoid(logits_diff_chunk).sum() / (
+            full_target.shape[0] // 2
+        )
+        return loss_chunk
 
     @staticmethod
     def forward(
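The reference-free branch above replaces missing reference log-probs with zeros, so the DPO log-ratios collapse to the raw policy log-probs. A small self-contained restatement of the chunk-level term under that reading; the names are illustrative and this skips the linear projection and chunk accumulation entirely:

import torch
import torch.nn.functional as F

def dpo_chunk_term(
    chosen_logps_chunk,
    rejected_logps_chunk,
    full_target,
    ref_chosen_logps_chunk=None,
    ref_rejected_logps_chunk=None,
    beta=0.1,
):
    # A missing reference model contributes zero log-probs, so the log-ratios
    # reduce to the policy log-probs themselves.
    if ref_chosen_logps_chunk is None:
        ref_chosen_logps_chunk = torch.zeros_like(chosen_logps_chunk)
    if ref_rejected_logps_chunk is None:
        ref_rejected_logps_chunk = torch.zeros_like(rejected_logps_chunk)
    logits_diff_chunk = beta * (
        (chosen_logps_chunk - ref_chosen_logps_chunk)
        - (rejected_logps_chunk - ref_rejected_logps_chunk)
    )
    return -F.logsigmoid(logits_diff_chunk).sum() / (full_target.shape[0] // 2)

chosen, rejected = torch.tensor([-1.0, -2.0]), torch.tensor([-3.0, -4.0])
full_target = torch.zeros(4, 8)
# With no reference model the term depends only on the policy margin.
print(dpo_chunk_term(chosen, rejected, full_target))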
diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index 10e726055..aaa571230 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -2,116 +2,24 @@
 from functools import partial
 
 import torch
-from torch.nn import functional as F
 
 
 class LigerFusedLinearDistillationBase(torch.autograd.Function):
 
     @abstractmethod
-    def distillation_loss_fn(student_logits, teacher_logits, temperature):
-        """
-        Compute distillation loss.
-        Args:
-            student_logits (torch.Tensor): Raw logits of student tokens. Shape: (batch_size * seq_len, vocab_size).
-            teacher_logits (torch.Tensor): Raw logits of teacher tokens. Shape: (batch_size * seq_len, vocab_size).
-        """
-        raise NotImplementedError("Distillation loss function must be implemented.")
-
-    @staticmethod
-    def chunk_forward(
-        student_input_chunk,
-        student_weight,
-        teacher_input_chunk,
-        teacher_weight,
-        target_chunk,
-        student_bias=None,
-        teacher_bias=None,
-        ignore_index=-100,
-        compute_ce_loss=True,
-    ):
-        # Student
-        student_logits_chunk = student_input_chunk @ student_weight.t()
-        if student_bias is not None:
-            student_logits_chunk += student_bias
-        student_log_probs_chunk = F.log_softmax(student_logits_chunk.float(), dim=-1)
-
-        # Teacher
-        with torch.no_grad():
-            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
-            if teacher_bias is not None:
-                teacher_logits_chunk += teacher_bias
-
-        # The hard/task loss
-        ce_loss = 0.0
-        if compute_ce_loss:
-            ce_loss = F.nll_loss(
-                student_log_probs_chunk.view(-1, student_log_probs_chunk.shape[-1]),
-                target_chunk.view(-1),
-                reduction="sum",
-                ignore_index=ignore_index,
-            )
-
-        return student_logits_chunk, teacher_logits_chunk, ce_loss
-
-    @staticmethod
-    def _compute_loss(
-        student_input_chunk,
-        student_weight,
-        teacher_input_chunk,
-        teacher_weight,
-        target_chunk,
-        student_bias=None,
-        teacher_bias=None,
-        distillation_loss_fn=None,
-        full_target=None,
-        ignore_index=-100,
-        temperature=1.0,
-        weight_hard_loss=0.5,
-        weight_soft_loss=0.5,
-        compute_ce_loss=True,
-        **loss_kwargs,
+    def distillation_loss_fn(
+        student_logits_chunk, teacher_logits_chunk, target_chunk, full_target, **kwargs
     ):
         """
-        Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function.
+        Compute distillation loss.
         Args:
-            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
-            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
-            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
-            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
-            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
+            student_logits_chunk (torch.Tensor): Chunk of student logits tensor. Shape: (chunk_size, vocab_size).
+            teacher_logits_chunk (torch.Tensor): Chunk of teacher logits tensor. Shape: (chunk_size, vocab_size).
             target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
-            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
-            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
             full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
-            ignore_index (int): Index to ignore for loss computation.
-            weight_hard_loss (float): Weight for hard loss.
-            weight_soft_loss (float): Weight for soft loss.
-            compute_ce_loss (bool): Whether to compute CE loss.
-            loss_kwargs (dict): Additional arguments for the loss function.
+            kwargs: Additional arguments for the loss function.
         """
-        student_logits_chunk, teacher_logits_chunk, hard_loss = (
-            LigerFusedLinearDistillationBase.chunk_forward(
-                student_input_chunk,
-                student_weight,
-                teacher_input_chunk,
-                teacher_weight,
-                target_chunk,
-                student_bias=student_bias,
-                teacher_bias=teacher_bias,
-                ignore_index=ignore_index,
-                compute_ce_loss=compute_ce_loss,
-            )
-        )
-
-        hard_loss /= full_target.shape[0]
-
-        soft_loss = distillation_loss_fn(
-            student_logits_chunk, teacher_logits_chunk, temperature
-        )
-        soft_loss /= full_target.shape[0]
-
-        loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
-        return loss, (soft_loss, hard_loss, student_logits_chunk, teacher_logits_chunk)
+        raise NotImplementedError("Distillation loss function must be implemented.")
 
     @staticmethod
     def forward(
@@ -125,11 +33,6 @@ def forward(
         teacher_bias=None,
         loss_fn=None,
         chunk_size=1024,
-        ignore_index=-100,
-        weight_hard_loss=0.5,
-        weight_soft_loss=0.5,
-        compute_ce_loss=True,
-        temperature=1.0,
         compiled=True,
         **loss_kwargs,
     ):
@@ -147,10 +50,6 @@ def forward(
             teacher_bias (torch.Tensor, optional): Teacher bias tensor. Shape: (vocab_size,).
             loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
             chunk_size (int): Size of a chunk.
-            compute_ce_loss (bool): Whether to compute CE loss.
-            ignore_index (int): Index to ignore for loss computation.
-            weight_hard_loss (float): Weight for hard/task loss.
-            weight_soft_loss (float): Weight for soft/distillation loss.
             compiled (bool): Whether to use torch compile for chunk accumulation.
             loss_kwargs (dict): Other possible arguments that a loss function might need
         """
+ """ + argnums = (0, 1, 5) if student_bias is not None else (0, 1) + return torch.func.grad_and_value( + compute_loss, argnums=argnums, has_aux=True + )( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias, + teacher_bias, + ) + def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): + (grad_input_chunk, grad_weight_chunk, *grad_bias_chunk), ( + chunk_loss, + ( + student_logits_chunk, + teacher_logits_chunk, + ), + ) = fused_fwd_bwd(student_input_chunk, teacher_input_chunk, target_chunk) + if student_bias is not None: - (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), ( - chunk_loss, - ( - chunk_soft_loss, - chunk_hard_loss, - chunk_student_logits, - chunk_teacher_logits, - ), - ) = torch.func.grad_and_value( - loss_func_to_call, argnums=(0, 1, 5), has_aux=True - )( - student_input_chunk, - student_weight, - teacher_input_chunk, - teacher_weight, - target_chunk, - student_bias, - teacher_bias, - ) - grad_bias.add_(chunk_grad_bias) - else: - (chunk_grad_input, chunk_grad_weight), ( - chunk_loss, - ( - chunk_soft_loss, - chunk_hard_loss, - chunk_student_logits, - chunk_teacher_logits, - ), - ) = torch.func.grad_and_value( - loss_func_to_call, argnums=(0, 1), has_aux=True - )( - student_input_chunk, - student_weight, - teacher_input_chunk, - teacher_weight, - target_chunk, - student_bias, - teacher_bias, - ) - grad_weight.add_(chunk_grad_weight) + grad_bias.add_(grad_bias_chunk) + + grad_weight.add_(grad_weight_chunk) loss_acc.add_(chunk_loss) - return chunk_grad_input + grad_inputs.append(grad_input_chunk) if compiled: - accumulate_chunk = torch.compile(accumulate_chunk) + fused_fwd_bwd = torch.compile(fused_fwd_bwd) num_chunks = max(1, student_input.shape[0] // CHUNK_SIZE) _student_input_chunks = torch.chunk(student_input, chunks=num_chunks, dim=0) @@ -229,10 +110,7 @@ def accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk): for student_input_chunk, teacher_input_chunk, target_chunk in zip( _student_input_chunks, _teacher_input_chunks, _target_chunks ): - grad_input = accumulate_chunk( - student_input_chunk, teacher_input_chunk, target_chunk - ) - grad_inputs.append(grad_input) + accumulate_chunk(student_input_chunk, teacher_input_chunk, target_chunk) ctx.save_for_backward( torch.cat(grad_inputs, dim=0), @@ -250,3 +128,51 @@ def backward(ctx, grad_output): grad_bias = grad_bias * grad_output if grad_bias is not None else None return grad_input, grad_weight, None, grad_bias + + @staticmethod + def _compute_loss( + student_input_chunk, + student_weight, + teacher_input_chunk, + teacher_weight, + target_chunk, + student_bias=None, + teacher_bias=None, + distillation_loss_fn=None, + full_target=None, + **loss_kwargs, + ): + """ + Compute the total loss for a chunk of input and target, while using an knowleedge distillation loss function. + Args: + student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size). + student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size). + teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size). + teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size). + target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,). + student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). + teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,). 
@@ -250,3 +128,51 @@ def backward(ctx, grad_output):
             grad_bias = grad_bias * grad_output if grad_bias is not None else None
 
         return grad_input, grad_weight, None, grad_bias
+
+    @staticmethod
+    def _compute_loss(
+        student_input_chunk,
+        student_weight,
+        teacher_input_chunk,
+        teacher_weight,
+        target_chunk,
+        student_bias=None,
+        teacher_bias=None,
+        distillation_loss_fn=None,
+        full_target=None,
+        **loss_kwargs,
+    ):
+        """
+        Compute the total loss for a chunk of input and target, while using a knowledge distillation loss function.
+        Args:
+            student_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, student_hidden_size).
+            student_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, student_hidden_size).
+            teacher_input_chunk (torch.Tensor): Chunk of input tensor. Shape: (chunk_size, teacher_hidden_size).
+            teacher_weight (torch.Tensor): Weight tensor. Shape: (vocab_size, teacher_hidden_size).
+            target_chunk (torch.Tensor): Chunk of target tensor. Shape: (chunk_size,).
+            student_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            teacher_bias (torch.Tensor, optional): Bias tensor. Shape: (vocab_size,).
+            distillation_loss_fn (callable): Loss function to compute the loss on a chunk of input/target.
+            full_target (torch.Tensor): Full target tensor. Shape: (chunk_size,).
+            loss_kwargs (dict): Additional arguments for the loss function.
+        """
+        # Student
+        student_logits_chunk = student_input_chunk @ student_weight.t()
+        if student_bias is not None:
+            student_logits_chunk += student_bias
+
+        # Teacher
+        with torch.no_grad():
+            teacher_logits_chunk = teacher_input_chunk @ teacher_weight.t()
+            if teacher_bias is not None:
+                teacher_logits_chunk += teacher_bias
+
+        loss_chunk = distillation_loss_fn(
+            student_logits_chunk,
+            teacher_logits_chunk,
+            target_chunk,
+            full_target,
+            **loss_kwargs,
+        )
+
+        return loss_chunk, (student_logits_chunk, teacher_logits_chunk)
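With ignore_index, temperature, weight_hard_loss, weight_soft_loss and compute_ce_loss removed from the base class, a concrete distillation_loss_fn now receives raw chunk logits plus both targets and owns that logic itself, picking up any extra settings through loss_kwargs. The following is a hypothetical subclass loss written against the new signature; the hard/soft weighting and temperature handling are assumptions for illustration, not one of the library's shipped losses:

import torch.nn.functional as F

def soft_hard_kd_loss_chunk(
    student_logits_chunk,   # (chunk_size, vocab_size)
    teacher_logits_chunk,   # (chunk_size, vocab_size)
    target_chunk,           # (chunk_size,) integer labels
    full_target,            # full, non-chunked target; used only for normalization
    weight_hard_loss=0.5,
    weight_soft_loss=0.5,
    temperature=1.0,
    ignore_index=-100,
):
    # Soft loss: temperature-scaled KL between teacher and student distributions.
    student_log_probs = F.log_softmax(student_logits_chunk.float() / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits_chunk.float() / temperature, dim=-1)
    soft_loss = F.kl_div(student_log_probs, teacher_probs, reduction="sum") * (temperature**2)

    # Hard loss: plain cross entropy of the student against the labels.
    hard_loss = F.cross_entropy(
        student_logits_chunk.float(),
        target_chunk,
        reduction="sum",
        ignore_index=ignore_index,
    )

    # Normalize by the full batch size so per-chunk losses can simply be summed.
    return (weight_hard_loss * hard_loss + weight_soft_loss * soft_loss) / full_target.shape[0]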
diff --git a/src/liger_kernel/chunked_loss/fused_linear_preference.py b/src/liger_kernel/chunked_loss/fused_linear_preference.py
index fff0791ec..3843dc1db 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_preference.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_preference.py
@@ -101,82 +101,68 @@ def fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk):
             """
             Fused forward and backward pass for a chunk of input and target.
             """
-            if bias is not None:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1, 3), has_aux=True
-                )(
-                    input_chunk,
-                    weight,
-                    target_chunk,
-                    bias,
-                    ref_input_chunk=ref_input_chunk,
-                )
-            else:
-                return torch.func.grad_and_value(
-                    compute_loss, argnums=(0, 1), has_aux=True
-                )(input_chunk, weight, target_chunk, ref_input_chunk=ref_input_chunk)
+            argnums = (0, 1, 3) if bias is not None else (0, 1)
+
+            return torch.func.grad_and_value(
+                compute_loss, argnums=argnums, has_aux=True
+            )(
+                input_chunk,
+                weight,
+                target_chunk,
+                bias,
+                ref_input_chunk=ref_input_chunk,
+            )
 
         def accumulate_chunk(input_chunk, target_chunk, ref_input_chunk=None):
+            (grad_input_chunk, grad_weight_chunk, *grad_bias_chunk), (
+                loss_chunk,
+                (
+                    chosen_logps_chunk,
+                    rejected_logps_chunk,
+                    chosen_logits_mean_chunk,
+                    rejected_logits_mean_chunk,
+                    nll_loss_chunk,
+                    *aux_outputs_chunk,
+                ),
+            ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+
             if bias is not None:
-                (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), (
-                    chunk_loss,
-                    (
-                        chunk_chosen_logps,
-                        chunk_rejected_logps,
-                        chunk_chosen_logits_mean,
-                        chunk_rejected_logits_mean,
-                        chunk_nll_loss,
-                        *aux_outputs,
-                    ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
-                grad_bias.add_(chunk_grad_bias)  # accumulate bias gradient
-            else:
-                (chunk_grad_input, chunk_grad_weight), (
-                    chunk_loss,
-                    (
-                        chunk_chosen_logps,
-                        chunk_rejected_logps,
-                        chunk_chosen_logits_mean,
-                        chunk_rejected_logits_mean,
-                        chunk_nll_loss,
-                        *aux_outputs,
-                    ),
-                ) = fused_fwd_bwd(input_chunk, target_chunk, ref_input_chunk)
+                grad_bias.add_(grad_bias_chunk[0])  # accumulate bias gradient
 
             # Accumulate gradients
-            grad_weight.add_(chunk_grad_weight)
-            grad_chosen_inputs.append(chunk_grad_input[: chosen_target_chunk.shape[0]])
+            grad_weight.add_(grad_weight_chunk)
+            grad_chosen_inputs.append(grad_input_chunk[: chosen_target_chunk.shape[0]])
             grad_rejected_inputs.append(
-                chunk_grad_input[chosen_target_chunk.shape[0] :]
+                grad_input_chunk[chosen_target_chunk.shape[0] :]
            )
 
             # Accumulate loss
-            loss_acc.add_(chunk_loss)
+            loss_acc.add_(loss_chunk)
 
             # Accumulate metrics
-            policy_chosen_logps.append(chunk_chosen_logps)
-            policy_rejected_logps.append(chunk_rejected_logps)
-            policy_chosen_logits_mean.add_(chunk_chosen_logits_mean)
-            policy_rejected_logits_mean.add_(chunk_rejected_logits_mean)
-            policy_nll_loss.add_(chunk_nll_loss)
+            policy_chosen_logps.append(chosen_logps_chunk)
+            policy_rejected_logps.append(rejected_logps_chunk)
+            policy_chosen_logits_mean.add_(chosen_logits_mean_chunk)
+            policy_rejected_logits_mean.add_(rejected_logits_mean_chunk)
+            policy_nll_loss.add_(nll_loss_chunk)
 
             # aux_outputs
             # Initialize storage for aux_outputs
             if len(aggregated_aux_outputs) == 0:
-                for aux in aux_outputs:
-                    if aux.ndim == 0:
+                for aux_chunk in aux_outputs_chunk:
+                    if aux_chunk.ndim == 0:
                         aggregated_aux_outputs.append(
-                            torch.zeros((), device=aux.device)
+                            torch.zeros((), device=aux_chunk.device)
                         )
                     else:
                         aggregated_aux_outputs.append([])
 
             # Process each aux_output
-            for i, aux in enumerate(aux_outputs):
-                if aux.ndim == 0:
-                    aggregated_aux_outputs[i].add_(aux)
+            for i, aux_chunk in enumerate(aux_outputs_chunk):
+                if aux_chunk.ndim == 0:
+                    aggregated_aux_outputs[i].add_(aux_chunk)
                 else:
-                    aggregated_aux_outputs[i].append(aux)
+                    aggregated_aux_outputs[i].append(aux_chunk)
 
         if compiled:
             fused_fwd_bwd = torch.compile(fused_fwd_bwd)
@@ -289,35 +275,37 @@ def chunk_forward(
             logits_chunk = logits_chunk + bias
         log_probs_chunk = F.log_softmax(logits_chunk.float(), dim=-1)
 
-        chosen_nll_loss = 0.0
+        chosen_nll_loss_chunk = 0.0
         if compute_nll_loss:
-            chosen_nll_loss = F.nll_loss(
+            chosen_nll_loss_chunk = F.nll_loss(
                 log_probs_chunk[:len_chosen_chunk].view(-1, log_probs_chunk.shape[-1]),
                 target_chunk[:len_chosen_chunk].view(-1),
                 reduction="sum",
                 ignore_index=ignore_index,
             )
 
-        loss_mask = target_chunk != ignore_index
-        label_chunk = torch.where(loss_mask, target_chunk, 0)
+        loss_mask_chunk = target_chunk != ignore_index
+        label_chunk = torch.where(loss_mask_chunk, target_chunk, 0)
 
-        per_token_logps = log_probs_chunk.gather(-1, label_chunk.unsqueeze(-1)).squeeze(
+        per_token_logps_chunk = log_probs_chunk.gather(
+            -1, label_chunk.unsqueeze(-1)
+        ).squeeze(-1)
+        average_log_prob_chunk = (per_token_logps_chunk * loss_mask_chunk).sum(
             -1
-        )
-        average_log_prob = (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
+        ) / loss_mask_chunk.sum(-1)
 
-        chosen_logps = average_log_prob[:len_chosen_chunk]
-        rejected_logps = average_log_prob[len_chosen_chunk:]
+        chosen_logps_chunk = average_log_prob_chunk[:len_chosen_chunk]
+        rejected_logps_chunk = average_log_prob_chunk[len_chosen_chunk:]
 
-        chosen_logits = logits_chunk[:len_chosen_chunk]
-        rejected_logits = logits_chunk[len_chosen_chunk:]
+        chosen_logits_chunk = logits_chunk[:len_chosen_chunk]
+        rejected_logits_chunk = logits_chunk[len_chosen_chunk:]
 
         return (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits,
-            rejected_logits,
-            chosen_nll_loss,
+            chosen_logps_chunk,
+            rejected_logps_chunk,
+            chosen_logits_chunk,
+            rejected_logits_chunk,
+            chosen_nll_loss_chunk,
        )
 
     @staticmethod
@@ -357,11 +345,11 @@ def _compute_loss(
             loss_kwargs (dict): Additional arguments for the loss function.
         """
         (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits,
-            rejected_logits,
-            chosen_nll_loss,
+            chosen_logps_chunk,
+            rejected_logps_chunk,
+            chosen_logits_chunk,
+            rejected_logits_chunk,
+            chosen_nll_loss_chunk,
         ) = LigerFusedLinearPreferenceBase.chunk_forward(
             input_chunk,
             weight,
@@ -370,25 +358,25 @@ def _compute_loss(
             ignore_index=ignore_index,
             compute_nll_loss=compute_nll_loss,
         )
-        chosen_nll_loss = (
-            chosen_nll_loss
+        chosen_nll_loss_chunk = (
+            chosen_nll_loss_chunk
             / (full_target[: full_target.shape[0] // 2] != ignore_index).sum()
         )
-        chosen_logits_mean = chosen_logits.sum() / (
+        chosen_logits_mean_chunk = chosen_logits_chunk.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
         )
-        rejected_logits_mean = rejected_logits.sum() / (
+        rejected_logits_mean_chunk = rejected_logits_chunk.sum() / (
             full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0]
         )
 
         if use_ref_model:
             with torch.no_grad():
                 (
-                    ref_chosen_logps,
-                    ref_rejected_logps,
-                    ref_chosen_logits,
-                    ref_rejected_logits,
-                    ref_chosen_nll_loss,
+                    ref_chosen_logps_chunk,
+                    ref_rejected_logps_chunk,
+                    _,
+                    _,
+                    _,
                 ) = LigerFusedLinearPreferenceBase.chunk_forward(
                     ref_input_chunk,
                     ref_weight,
@@ -397,23 +385,27 @@ def _compute_loss(
                     ignore_index=ignore_index,
                     compute_nll_loss=False,  # We don't need NLL loss for the reference model
                 )
-            loss_kwargs["ref_chosen_logps"] = ref_chosen_logps
-            loss_kwargs["ref_rejected_logps"] = ref_rejected_logps
+            loss_kwargs["ref_chosen_logps_chunk"] = ref_chosen_logps_chunk
+            loss_kwargs["ref_rejected_logps_chunk"] = ref_rejected_logps_chunk
 
         preference_loss_outputs = preference_loss_fn(
-            chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs
+            chosen_logps_chunk,
+            rejected_logps_chunk,
+            full_target,
+            beta=beta,
+            **loss_kwargs,
         )
         if isinstance(preference_loss_outputs, tuple):
-            preference_loss, *aux_outputs = preference_loss_outputs
+            preference_loss_chunk, *aux_outputs_chunk = preference_loss_outputs
        else:
-            preference_loss, aux_outputs = preference_loss_outputs, []
-
-        loss = alpha * chosen_nll_loss - preference_loss
-        return_vars = (
-            chosen_logps,
-            rejected_logps,
-            chosen_logits_mean,
-            rejected_logits_mean,
-            chosen_nll_loss,
+            preference_loss_chunk, aux_outputs_chunk = preference_loss_outputs, []
+
+        loss_chunk = alpha * chosen_nll_loss_chunk - preference_loss_chunk
+        return_vars_chunk = (
+            chosen_logps_chunk,
+            rejected_logps_chunk,
+            chosen_logits_mean_chunk,
+            rejected_logits_mean_chunk,
+            chosen_nll_loss_chunk,
        )
-        return loss, (*return_vars, *aux_outputs)
+        return loss_chunk, (*return_vars_chunk, *aux_outputs_chunk)
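The argnums = (0, 1, 3) if bias is not None else (0, 1) simplification works because torch.func.grad_and_value returns one gradient per entry in argnums, so a star target can absorb the optional bias gradient and is simply an empty list when no bias gradient is requested. A self-contained toy illustration (the loss and names are invented for this example, not part of the library):

import torch

def toy_loss(x, w, b):
    # Toy stand-in for _compute_loss: returns (loss, aux) so has_aux=True applies.
    y = x @ w.t()
    if b is not None:
        y = y + b
    return y.square().mean(), y

x, w, b = torch.randn(4, 3), torch.randn(5, 3), torch.randn(5)

for bias in (b, None):
    argnums = (0, 1, 2) if bias is not None else (0, 1)
    (grad_x, grad_w, *grad_b), (loss, y) = torch.func.grad_and_value(
        toy_loss, argnums=argnums, has_aux=True
    )(x, w, bias)
    # grad_b is [db] when a bias is passed and [] otherwise; the same unpacking handles both.
    print(loss.item(), grad_x.shape, grad_w.shape, len(grad_b))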
""" ( - chosen_logps, - rejected_logps, - chosen_logits, - rejected_logits, - chosen_nll_loss, + chosen_logps_chunk, + rejected_logps_chunk, + chosen_logits_chunk, + rejected_logits_chunk, + chosen_nll_loss_chunk, ) = LigerFusedLinearPreferenceBase.chunk_forward( input_chunk, weight, @@ -370,25 +358,25 @@ def _compute_loss( ignore_index=ignore_index, compute_nll_loss=compute_nll_loss, ) - chosen_nll_loss = ( - chosen_nll_loss + chosen_nll_loss_chunk = ( + chosen_nll_loss_chunk / (full_target[: full_target.shape[0] // 2] != ignore_index).sum() ) - chosen_logits_mean = chosen_logits.sum() / ( + chosen_logits_mean_chunk = chosen_logits_chunk.sum() / ( full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] ) - rejected_logits_mean = rejected_logits.sum() / ( + rejected_logits_mean_chunk = rejected_logits_chunk.sum() / ( full_target.shape[0] // 2 * input_chunk.shape[1] * weight.shape[0] ) if use_ref_model: with torch.no_grad(): ( - ref_chosen_logps, - ref_rejected_logps, - ref_chosen_logits, - ref_rejected_logits, - ref_chosen_nll_loss, + ref_chosen_logps_chunk, + ref_rejected_logps_chunk, + _, + _, + _, ) = LigerFusedLinearPreferenceBase.chunk_forward( ref_input_chunk, ref_weight, @@ -397,23 +385,27 @@ def _compute_loss( ignore_index=ignore_index, compute_nll_loss=False, # We don't need NLL loss for the reference model ) - loss_kwargs["ref_chosen_logps"] = ref_chosen_logps - loss_kwargs["ref_rejected_logps"] = ref_rejected_logps + loss_kwargs["ref_chosen_logps_chunk"] = ref_chosen_logps_chunk + loss_kwargs["ref_rejected_logps_chunk"] = ref_rejected_logps_chunk preference_loss_outputs = preference_loss_fn( - chosen_logps, rejected_logps, full_target, beta=beta, **loss_kwargs + chosen_logps_chunk, + rejected_logps_chunk, + full_target, + beta=beta, + **loss_kwargs, ) if isinstance(preference_loss_outputs, tuple): - preference_loss, *aux_outputs = preference_loss_outputs + preference_loss_chunk, *aux_outputs_chunk = preference_loss_outputs else: - preference_loss, aux_outputs = preference_loss_outputs, [] - - loss = alpha * chosen_nll_loss - preference_loss - return_vars = ( - chosen_logps, - rejected_logps, - chosen_logits_mean, - rejected_logits_mean, - chosen_nll_loss, + preference_loss_chunk, aux_outputs_chunk = preference_loss_outputs, [] + + loss_chunk = alpha * chosen_nll_loss_chunk - preference_loss_chunk + return_vars_chunk = ( + chosen_logps_chunk, + rejected_logps_chunk, + chosen_logits_mean_chunk, + rejected_logits_mean_chunk, + chosen_nll_loss_chunk, ) - return loss, (*return_vars, *aux_outputs) + return loss_chunk, (*return_vars_chunk, *aux_outputs_chunk) diff --git a/src/liger_kernel/chunked_loss/orpo_loss.py b/src/liger_kernel/chunked_loss/orpo_loss.py index c860d4bd9..c8d03bb3e 100644 --- a/src/liger_kernel/chunked_loss/orpo_loss.py +++ b/src/liger_kernel/chunked_loss/orpo_loss.py @@ -9,7 +9,9 @@ class LigerFusedLinearORPOFunction(LigerFusedLinearPreferenceBase): @staticmethod - def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): + def preference_loss_fn( + chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1 + ): """ Paper: https://arxiv.org/pdf/2403.07691 @@ -26,25 +28,31 @@ def preference_loss_fn(chosen_logps, rejected_logps, full_target, beta=0.1): - odds_θ: Odds function for the policy Args: - chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,). - rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,). 
diff --git a/src/liger_kernel/chunked_loss/simpo_loss.py b/src/liger_kernel/chunked_loss/simpo_loss.py
index 7efa0603d..fe2253f8f 100644
--- a/src/liger_kernel/chunked_loss/simpo_loss.py
+++ b/src/liger_kernel/chunked_loss/simpo_loss.py
@@ -10,7 +10,7 @@ class LigerFusedLinearSimPOFunction(LigerFusedLinearPreferenceBase):
 
     @staticmethod
     def preference_loss_fn(
-        chosen_logps, rejected_logps, full_target, beta=0.1, gamma=0.5
+        chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1, gamma=0.5
     ):
         """
         Paper: https://arxiv.org/pdf/2405.14734
@@ -28,15 +28,15 @@ def preference_loss_fn(
         - γ: gemma margin term
 
         Args:
-            chosen_logps (torch.Tensor): Avg log probabilities of chosen tokens. Shape: (batch_size,).
-            rejected_logps (torch.Tensor): Avg log probabilities of rejected tokens. Shape: (batch_size,).
+            chosen_logps_chunk (torch.Tensor): Avg log probabilities of chosen tokens in the chunk. Shape: (batch_size,).
+            rejected_logps_chunk (torch.Tensor): Avg log probabilities of rejected tokens in the chunk. Shape: (batch_size,).
             full_target: Non chunked full target tensor
             beta (float): beta weight
             gamma (float): gemma margin term
         """
-        logits = beta * (chosen_logps - rejected_logps) - gamma
-        loss = F.logsigmoid(logits).sum() / (full_target.shape[0] // 2)
-        return loss
+        logits_chunk = beta * (chosen_logps_chunk - rejected_logps_chunk) - gamma
+        loss_chunk = F.logsigmoid(logits_chunk).sum() / (full_target.shape[0] // 2)
+        return loss_chunk
 
     @staticmethod
     def forward(
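SimPO's chunk term is the CPO term with a gamma margin subtracted and no reference model; a larger gamma demands a bigger chosen-vs-rejected gap before the log-sigmoid saturates. A tiny standalone check with illustrative values (not the fused kernel):

import torch
import torch.nn.functional as F

def simpo_chunk_term(chosen_logps_chunk, rejected_logps_chunk, full_target, beta=0.1, gamma=0.5):
    # Same shape of computation as the chunk-level term above.
    logits_chunk = beta * (chosen_logps_chunk - rejected_logps_chunk) - gamma
    return F.logsigmoid(logits_chunk).sum() / (full_target.shape[0] // 2)

chosen, rejected = torch.tensor([-1.0, -2.0]), torch.tensor([-1.5, -4.0])
full_target = torch.zeros(4, 8)
# Increasing the margin lowers the log-sigmoid term for the same policy gap.
print(simpo_chunk_term(chosen, rejected, full_target, gamma=0.0))
print(simpo_chunk_term(chosen, rejected, full_target, gamma=0.5))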