Skip to content

Commit 88ad1a0

Browse files
authored
fix orpo chosen-nll loss (huggingface#2502)
1 parent 9908dda commit 88ad1a0

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

trl/trainer/orpo_trainer.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -774,18 +774,12 @@ def cross_entropy_loss(logits, labels):
old new   change
774 774   loss = loss_fct(logits, labels)
775 775   return loss
776 776
777     - if self.is_encoder_decoder:
778     - labels = concatenated_batch["concatenated_labels"].clone()
779     - else:
780     - labels = concatenated_batch["concatenated_input_ids"].clone()
781     - attention_mask = concatenated_batch["concatenated_attention_mask"]
782     - labels = torch.where(attention_mask == 1, labels, self.label_pad_token_id)
783     -
    777 + labels = concatenated_batch["concatenated_labels"].clone()
784 778   chosen_nll_loss = cross_entropy_loss(all_logits[:len_chosen], labels[:len_chosen])
785 779
786 780   all_logps = self.get_batch_logps(
787 781   all_logits,
788     - concatenated_batch["concatenated_labels"],
    782 + labels,
789 783   average_log_prob=True,
790 784   is_encoder_decoder=self.is_encoder_decoder,
791 785   label_pad_token_id=self.label_pad_token_id,
@@ -794,8 +788,12 @@ def cross_entropy_loss(logits, labels):
old new   change
794 788   chosen_logps = all_logps[:len_chosen]
795 789   rejected_logps = all_logps[len_chosen:]
796 790
797     - chosen_logits = all_logits[:len_chosen]
798     - rejected_logits = all_logits[len_chosen:]
    791 + if not self.is_encoder_decoder:
    792 + chosen_logits = all_logits[:len_chosen, :-1, :]
    793 + rejected_logits = all_logits[len_chosen:, :-1, :]
    794 + else:
    795 + chosen_logits = all_logits[:len_chosen]
    796 + rejected_logits = all_logits[len_chosen:]
799 797
800 798   if self.aux_loss_enabled:
801 799   return (chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_nll_loss, outputs.aux_loss)

0 commit comments

Comments
 (0)