We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 66570b1 commit 02bac05Copy full SHA for 02bac05
src/liger_kernel/transformers/model/gemma3.py
@@ -255,7 +255,7 @@ def multimodal_forward(
255
shift_labels = shift_labels.view(-1).to(hidden_device)
256
257
lce = LigerFusedLinearCrossEntropyLoss()
258
- loss = lce(self.language_model.lm_head.weight, shift_hidden_states, shift_labels)
+ loss = lce(self.lm_head.weight, shift_hidden_states, shift_labels)
259
else:
260
logits = self.lm_head(kept_hidden_states)
261
if labels is not None:
0 commit comments