Commit 983b674
[Qwen3]: If qwen3 is used along with peft config, peft adds opcl obj not ingested further (#926)

## Summary

Fixes #925.

Fixes the `TypeError: liger_fused_linear_cross_entropy() got an unexpected keyword argument 'return_dict'` that occurs when using Liger Kernel with PEFT and the transformers `Trainer`. The `return_dict` parameter is a standard transformers argument that controls the output format (`ModelOutput` vs. tuple). When PEFT wraps a Liger Kernel model, this parameter is forwarded through `**kwargs` all the way down to `liger_fused_linear_cross_entropy()`, which does not accept it, so training crashes.

This PR adds `kwargs.pop("return_dict", None)` in all affected model files to remove the parameter before it reaches the loss calculation functions.

## Details

Root cause:

- transformers `Trainer` passes `return_dict` in the model inputs.
- The PEFT wrapper forwards all kwargs to the base model.
- Liger Kernel model implementations pass `**kwargs` to `LigerForCausalLMLoss()`.
- This propagates to `liger_fused_linear_cross_entropy()`, which does not accept `return_dict`.

## Testing Done

- Verified the fix resolves the `TypeError` with Qwen3 + PEFT + transformers `Trainer`.
- Confirmed training runs complete successfully without the error.
- Hardware Type: H100 × 8 GPUs
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
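To make the failure mode concrete, here is a minimal, self-contained sketch (not the actual Liger Kernel source; `fused_loss` and `lce_forward` below are simplified stand-ins for `liger_fused_linear_cross_entropy()` and the model's forward) of how `return_dict` leaks through `**kwargs` and why popping it fixes the crash:

```python
# Stand-in for liger_fused_linear_cross_entropy(): it has a fixed keyword
# signature and raises TypeError on any unexpected keyword argument.
def fused_loss(hidden_states, labels, ignore_index=-100):
    # Placeholder "loss"; the point is the fixed signature, not the math.
    return sum(h for h, l in zip(hidden_states, labels) if l != ignore_index)


# Stand-in for the patched model forward. Trainer -> PEFT wrapper -> forward:
# return_dict arrives here inside **kwargs.
def lce_forward(hidden_states, labels=None, **kwargs):
    kwargs.pop("return_dict", None)  # the fix: strip it before the loss call
    return fused_loss(hidden_states, labels, **kwargs)


# Without the pop() above, this call would raise:
# TypeError: fused_loss() got an unexpected keyword argument 'return_dict'
print(lce_forward([1.0, 2.0], labels=[0, -100], return_dict=True))
```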
1 parent 4c32ab6 commit 983b674

File tree

1 file changed: +2 −0 lines
  • src/liger_kernel/transformers/model/qwen3.py

src/liger_kernel/transformers/model/qwen3.py — 2 additions, 0 deletions

```diff
@@ -83,6 +83,8 @@ def lce_forward(
     kept_hidden_states = hidden_states[:, slice_indices, :]
 
     shift_labels = kwargs.pop("shift_labels", None)
+    # Remove output-control parameters that shouldn't be passed to loss functions
+    kwargs.pop("return_dict", None)
     logits = None
     loss = None
     token_accuracy = None
```
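For context, a hedged sketch of the setup that previously triggered the error, assuming the usual Liger Kernel patching entry point for Qwen3 (`apply_liger_kernel_to_qwen3`) and a standard PEFT + `Trainer` run; the model checkpoint, LoRA targets, and dataset are illustrative placeholders:

```python
from liger_kernel.transformers import apply_liger_kernel_to_qwen3  # assumed entry point
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

# Patch Qwen3 so its forward uses the fused linear cross-entropy loss path.
apply_liger_kernel_to_qwen3(fused_linear_cross_entropy=True)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-0.6B")  # illustrative checkpoint
model = get_peft_model(
    model,
    LoraConfig(r=8, target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM"),
)

# Before this fix, Trainer passed return_dict through the PEFT wrapper's
# **kwargs into lce_forward, and training crashed with:
#   TypeError: liger_fused_linear_cross_entropy() got an unexpected
#   keyword argument 'return_dict'
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="out"),
    train_dataset=train_dataset,  # any tokenized causal-LM dataset
)
trainer.train()
```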
