@@ -774,18 +774,12 @@ def cross_entropy_loss(logits, labels):
774774 loss = loss_fct (logits , labels )
775775 return loss
776776
777- if self .is_encoder_decoder :
778- labels = concatenated_batch ["concatenated_labels" ].clone ()
779- else :
780- labels = concatenated_batch ["concatenated_input_ids" ].clone ()
781- attention_mask = concatenated_batch ["concatenated_attention_mask" ]
782- labels = torch .where (attention_mask == 1 , labels , self .label_pad_token_id )
783-
777+ labels = concatenated_batch ["concatenated_labels" ].clone ()
784778 chosen_nll_loss = cross_entropy_loss (all_logits [:len_chosen ], labels [:len_chosen ])
785779
786780 all_logps = self .get_batch_logps (
787781 all_logits ,
788- concatenated_batch [ "concatenated_labels" ] ,
782+ labels ,
789783 average_log_prob = True ,
790784 is_encoder_decoder = self .is_encoder_decoder ,
791785 label_pad_token_id = self .label_pad_token_id ,
@@ -794,8 +788,12 @@ def cross_entropy_loss(logits, labels):
794788 chosen_logps = all_logps [:len_chosen ]
795789 rejected_logps = all_logps [len_chosen :]
796790
797- chosen_logits = all_logits [:len_chosen ]
798- rejected_logits = all_logits [len_chosen :]
791+ if not self .is_encoder_decoder :
792+ chosen_logits = all_logits [:len_chosen , :- 1 , :]
793+ rejected_logits = all_logits [len_chosen :, :- 1 , :]
794+ else :
795+ chosen_logits = all_logits [:len_chosen ]
796+ rejected_logits = all_logits [len_chosen :]
799797
800798 if self .aux_loss_enabled :
801799 return (chosen_logps , rejected_logps , chosen_logits , rejected_logits , chosen_nll_loss , outputs .aux_loss )
0 commit comments