Skip to content

Commit 02bac05

Browse files
authored
Fix gemma3 forward with skip_logits (#795)
Seems like this one was overlooked in #787
1 parent 66570b1 commit 02bac05

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/liger_kernel/transformers/model/gemma3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def multimodal_forward(
255255
shift_labels = shift_labels.view(-1).to(hidden_device)
256256

257257
lce = LigerFusedLinearCrossEntropyLoss()
258-
loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
258+
loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
259259
else:
260260
logits = self.lm_head(kept_hidden_states)
261261
if labels is not None:

0 commit comments

Comments
 (0)