diff --git a/src/liger_kernel/transformers/model/loss_utils.py b/src/liger_kernel/transformers/model/loss_utils.py index 279fa160d..8436c0ba5 100644 --- a/src/liger_kernel/transformers/model/loss_utils.py +++ b/src/liger_kernel/transformers/model/loss_utils.py @@ -23,6 +23,7 @@ def fixed_fused_linear_cross_entropy( reduction=reduction, ignore_index=ignore_index, softcap=final_logit_softcapping, + **kwargs, ) if reduction == "sum": loss = loss / num_items_in_batch