From 25eb714d27e059de16310a907e18e0b3581c362e Mon Sep 17 00:00:00 2001 From: yeshsurya Date: Sun, 2 Nov 2025 15:45:52 +0530 Subject: [PATCH 1/5] [Feat]: Adding support for gpt-oss --- src/liger_kernel/transformers/__init__.py | 3 + .../transformers/model/gpt_oss.py | 138 ++++++++++++++++++ src/liger_kernel/transformers/monkey_patch.py | 77 ++++++++++ test/convergence/bf16/test_mini_models.py | 48 ++++++ test/utils.py | 12 ++ 5 files changed, 278 insertions(+) create mode 100644 src/liger_kernel/transformers/model/gpt_oss.py diff --git a/src/liger_kernel/transformers/__init__.py b/src/liger_kernel/transformers/__init__.py index 54434d77c..c0cede1a7 100644 --- a/src/liger_kernel/transformers/__init__.py +++ b/src/liger_kernel/transformers/__init__.py @@ -39,6 +39,7 @@ from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401 + from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl # noqa: F401 from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401 @@ -105,6 +106,7 @@ def __getattr__(name: str): "apply_liger_kernel_to_glm4", "apply_liger_kernel_to_glm4v", "apply_liger_kernel_to_glm4v_moe", + "apply_liger_kernel_to_gpt_oss", "apply_liger_kernel_to_granite", "apply_liger_kernel_to_internvl", "apply_liger_kernel_to_llama", @@ -177,6 +179,7 @@ def __getattr__(name: str): "apply_liger_kernel_to_glm4", "apply_liger_kernel_to_glm4v", "apply_liger_kernel_to_glm4v_moe", + "apply_liger_kernel_to_gpt_oss", "apply_liger_kernel_to_granite", "apply_liger_kernel_to_internvl", "apply_liger_kernel_to_llama", diff --git 
a/src/liger_kernel/transformers/model/gpt_oss.py b/src/liger_kernel/transformers/model/gpt_oss.py new file mode 100644 index 000000000..de258e0b9 --- /dev/null +++ b/src/liger_kernel/transformers/model/gpt_oss.py @@ -0,0 +1,138 @@ +from typing import List +from typing import Optional +from typing import Union + +import torch + +from transformers.modeling_outputs import MoeCausalLMOutputWithPast +from transformers.modeling_outputs import MoeModelOutputWithPast +from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func + +from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss + + +def lce_forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + skip_logits: Optional[bool] = None, + **kwargs, +) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + logits_to_keep (`int` or `torch.Tensor`, *optional*): + If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). 
Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension. + This is useful when using packed tensor format (single dimension for batch and sequence length). + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
+ ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + kept_hidden_states = hidden_states[:, slice_indices, :] + + shift_labels = kwargs.pop("shift_labels", None) + logits = None + loss = None + + if skip_logits is None: + skip_logits = self.training and (labels is not None or shift_labels is not None) + + if skip_logits: + loss = LigerForCausalLMLoss( + hidden_states=kept_hidden_states, + lm_head_weight=self.lm_head.weight, + labels=labels, + shift_labels=shift_labels, + hidden_size=self.config.hidden_size, + **kwargs, + ) + else: # if in inference model materialize logits + logits = self.lm_head(kept_hidden_states) + if labels is not None or shift_labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + shift_labels=shift_labels, + vocab_size=self.vocab_size, + **kwargs, + ) + + aux_loss = None + if output_router_logits: + aux_loss = load_balancing_loss_func( + outputs.router_logits, + 
self.num_experts, + self.num_experts_per_tok, + attention_mask, + ) + if labels is not None: + loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device + + return MoeCausalLMOutputWithPast( + loss=loss, + aux_loss=aux_loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 940a0e222..2b217c5e9 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -21,6 +21,7 @@ from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected +from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward @@ -1465,6 +1466,81 @@ def apply_liger_kernel_to_qwen3_moe( _patch_rms_norm_module(decoder_layer.post_attention_layernorm) +def apply_liger_kernel_to_gpt_oss( + rope: bool = True, + cross_entropy: bool = False, + fused_linear_cross_entropy: bool = True, + rms_norm: bool = True, + swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation + model: PreTrainedModel = None, +) -> None: + """ + Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models. 
+ NOTE: GPT-OSS is supported in transformers >= 4.55.0 + NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert + implementation with clamping and MXFP4 quantization. + + Args: + rope (bool): Whether to apply Liger's rotary position embedding. Default is True. + cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False. + fused_linear_cross_entropy (bool): + Whether to apply Liger's fused linear cross entropy loss. Default is True. + `cross_entropy` and `fused_linear_cross_entropy` cannot both be True. + If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient. + rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True. + swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False. + Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default. + model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been + loaded. Default is None. + """ + if version.parse(transformers.__version__) < version.parse("4.55.0"): + logger.warning("GPT-OSS support requires transformers >= 4.55.0") + return + + assert not (cross_entropy and fused_linear_cross_entropy), ( + "cross_entropy and fused_linear_cross_entropy cannot both be True." 
+ ) + + from transformers.models.gpt_oss import modeling_gpt_oss + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel + + from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward + + if rope: + modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb + + if rms_norm: + modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm + + if cross_entropy: + from transformers.loss.loss_utils import nn + + nn.functional.cross_entropy = liger_cross_entropy + + if fused_linear_cross_entropy: + if model is not None: + model.forward = MethodType(gpt_oss_lce_forward, model) + else: + modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward + + # Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation + # with clamping (swiglu_limit=7.0) and MXFP4 quantization + + if model is not None: + # The model instance already exists, so we need to additionally patch the + # instance variables that reference already-instantiated modules + + # get the base model from the model instance + base_model: GptOssModel = getattr(model, model.base_model_prefix, model) + + if rms_norm: + _patch_rms_norm_module(base_model.norm) + for decoder_layer in base_model.layers: + if rms_norm: + _patch_rms_norm_module(decoder_layer.input_layernorm) + _patch_rms_norm_module(decoder_layer.post_attention_layernorm) + + def apply_liger_kernel_to_qwen2_vl( rope: bool = True, cross_entropy: bool = False, @@ -2571,6 +2647,7 @@ def apply_liger_kernel_to_qwen3_next( "glm4": apply_liger_kernel_to_glm4, "glm4v": apply_liger_kernel_to_glm4v, "glm4v_moe": apply_liger_kernel_to_glm4v_moe, + "gpt_oss": apply_liger_kernel_to_gpt_oss, "internvl": apply_liger_kernel_to_internvl, "llama": apply_liger_kernel_to_llama, "llama4_text": apply_liger_kernel_to_llama4, diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py index 1a2697b50..a038bfbdd 100644 --- a/test/convergence/bf16/test_mini_models.py +++ 
b/test/convergence/bf16/test_mini_models.py @@ -29,6 +29,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_glm4 from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe +from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss from liger_kernel.transformers import apply_liger_kernel_to_granite from liger_kernel.transformers import apply_liger_kernel_to_internvl from liger_kernel.transformers import apply_liger_kernel_to_llama @@ -61,6 +62,7 @@ from test.utils import revert_liger_kernel_to_glm4 from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe +from test.utils import revert_liger_kernel_to_gpt_oss from test.utils import revert_liger_kernel_to_granite from test.utils import revert_liger_kernel_to_internvl from test.utils import revert_liger_kernel_to_llama @@ -252,6 +254,15 @@ except ImportError: FALCONH1_AVAILABLE = False +try: + # GPT-OSS is only available in transformers>=4.55.0 + from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + GPT_OSS_AVAILABLE = True +except ImportError: + GPT_OSS_AVAILABLE = False + try: # Qwen3Next is only available in transformers>=4.57.0 from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig @@ -588,6 +599,43 @@ ), ) +if GPT_OSS_AVAILABLE: + MINI_MODEL_SETUPS["mini_gpt_oss"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gpt_oss, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gpt_oss, + model_class=GptOssForCausalLM, + mini_model_config=GptOssConfig( + vocab_size=32000, # 201088 + hidden_size=896, + intermediate_size=896, # Same as hidden_size for GPT-OSS + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=8192, + initializer_range=0.02, 
+ rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_parameters={ + "rope_type": "yarn", + "factor": 8.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "truncate": False, + "original_max_position_embeddings": 4096, + }, + attention_dropout=0.0, + num_local_experts=8, # Reduced from 32 for mini model + num_experts_per_tok=2, # Reduced from 4 for mini model + router_aux_loss_coef=0.9, + output_router_logits=False, + sliding_window=128, + layer_types=["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(4)], + ), + ) + if GEMMA3_AVAILABLE: MINI_MODEL_SETUPS["mini_gemma3_text"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_gemma3_text, diff --git a/test/utils.py b/test/utils.py index e4fd271dd..14ec9d222 100644 --- a/test/utils.py +++ b/test/utils.py @@ -482,6 +482,18 @@ def revert_liger_kernel_to_qwen3_moe(model_config: MiniModelConfig): print("Liger kernel patches have been reverted.") +def revert_liger_kernel_to_gpt_oss(model_config: MiniModelConfig): + """ + Revert all Liger kernel patches applied to GPT-OSS. + """ + from transformers.models.gpt_oss import modeling_gpt_oss + + importlib.reload(modeling_gpt_oss) + model_config.model_class = modeling_gpt_oss.GptOssForCausalLM + + print("Liger kernel patches have been reverted.") + + def revert_liger_kernel_to_qwen2_vl(model_config: MiniModelConfig): """ Revert all Liger kernel patches applied to Qwen2-VL. 
From f2e25e3052c8c129995a5b580c723604a3ffe846 Mon Sep 17 00:00:00 2001 From: yeshsurya Date: Mon, 3 Nov 2025 00:17:22 +0530 Subject: [PATCH 2/5] [feat]: completing test invocation --- test/convergence/bf16/test_mini_models.py | 19 +++++++ test/convergence/fp32/test_mini_models.py | 64 +++++++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/test/convergence/bf16/test_mini_models.py b/test/convergence/bf16/test_mini_models.py index a038bfbdd..cac6012ea 100644 --- a/test/convergence/bf16/test_mini_models.py +++ b/test/convergence/bf16/test_mini_models.py @@ -1595,6 +1595,25 @@ def run_mini_model( ), ], ), + pytest.param( + "mini_gpt_oss", + 32, + 1e-5, + torch.bfloat16, + 5e-2, + 5e-2, + 1e-1, + 1e-1, + 1e-2, + 1e-2, + marks=[ + pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"), + pytest.mark.skipif( + not GPT_OSS_AVAILABLE, + reason="GPT-OSS not available in this version of transformers", + ), + ], + ), pytest.param( "mini_qwen2_vl", 32, diff --git a/test/convergence/fp32/test_mini_models.py b/test/convergence/fp32/test_mini_models.py index 1c7fc88d7..80a0cec7d 100644 --- a/test/convergence/fp32/test_mini_models.py +++ b/test/convergence/fp32/test_mini_models.py @@ -29,6 +29,7 @@ from liger_kernel.transformers import apply_liger_kernel_to_glm4 from liger_kernel.transformers import apply_liger_kernel_to_glm4v from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe +from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss from liger_kernel.transformers import apply_liger_kernel_to_granite from liger_kernel.transformers import apply_liger_kernel_to_internvl from liger_kernel.transformers import apply_liger_kernel_to_llama @@ -63,6 +64,7 @@ from test.utils import revert_liger_kernel_to_glm4 from test.utils import revert_liger_kernel_to_glm4v from test.utils import revert_liger_kernel_to_glm4v_moe +from test.utils import revert_liger_kernel_to_gpt_oss from test.utils import 
revert_liger_kernel_to_granite from test.utils import revert_liger_kernel_to_internvl from test.utils import revert_liger_kernel_to_llama @@ -236,6 +238,15 @@ except ImportError: QWEN3_AVAILABLE = False +try: + # GPT-OSS is only available in transformers>=4.55.0 + from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + GPT_OSS_AVAILABLE = True +except ImportError: + GPT_OSS_AVAILABLE = False + try: # InternVL is only available in transformers>=4.52.1 from transformers.models.internvl.configuration_internvl import InternVLConfig @@ -589,6 +600,43 @@ ), ) +if GPT_OSS_AVAILABLE: + MINI_MODEL_SETUPS["mini_gpt_oss"] = MiniModelConfig( + liger_kernel_patch_func=apply_liger_kernel_to_gpt_oss, + liger_kernel_patch_revert_func=revert_liger_kernel_to_gpt_oss, + model_class=GptOssForCausalLM, + mini_model_config=GptOssConfig( + vocab_size=32000, # 201088 + hidden_size=896, + intermediate_size=896, # Same as hidden_size for GPT-OSS + num_hidden_layers=4, + num_attention_heads=8, + num_key_value_heads=2, + head_dim=64, + hidden_act="silu", + max_position_embeddings=8192, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_parameters={ + "rope_type": "yarn", + "factor": 8.0, + "beta_fast": 32.0, + "beta_slow": 1.0, + "truncate": False, + "original_max_position_embeddings": 4096, + }, + attention_dropout=0.0, + num_local_experts=8, # Reduced from 32 for mini model + num_experts_per_tok=2, # Reduced from 4 for mini model + router_aux_loss_coef=0.9, + output_router_logits=False, + sliding_window=128, + layer_types=["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(4)], + ), + ) + if GEMMA3_AVAILABLE: MINI_MODEL_SETUPS["mini_gemma3_text"] = MiniModelConfig( liger_kernel_patch_func=apply_liger_kernel_to_gemma3_text, @@ -1501,6 +1549,22 @@ def run_mini_model( reason="Qwen3 not available in this version of 
transformers", ), ), + pytest.param( + "mini_gpt_oss", + 32, + 1e-5, + torch.float32, + 1e-8, + 1e-5, + 5e-3, + 1e-5, + 5e-3, + 1e-5, + marks=pytest.mark.skipif( + not GPT_OSS_AVAILABLE, + reason="GPT-OSS not available in this version of transformers", + ), + ), pytest.param( # qwen2_vl requires slightly larger tolerances to pass this test after bug fix to qwen2_vl in transformers v4.47.0 "mini_qwen2_vl", 32, From 6eef643bc846a4353a661c9e3550a41ef1d27bd4 Mon Sep 17 00:00:00 2001 From: Yeshwanth Nagaraj Date: Sat, 22 Nov 2025 12:42:28 +0000 Subject: [PATCH 3/5] [chore]: style compliance --- src/liger_kernel/transformers/monkey_patch.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/liger_kernel/transformers/monkey_patch.py b/src/liger_kernel/transformers/monkey_patch.py index 46efacb63..8aa4ce312 100755 --- a/src/liger_kernel/transformers/monkey_patch.py +++ b/src/liger_kernel/transformers/monkey_patch.py @@ -1498,8 +1498,6 @@ def apply_liger_kernel_to_gpt_oss( from transformers.models.gpt_oss import modeling_gpt_oss from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel - from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward - if rope: modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb From 6bc115891759b5c3da79ae9a5391a3a5b272f44e Mon Sep 17 00:00:00 2001 From: Yeshwanth Nagaraj Date: Sat, 22 Nov 2025 13:05:20 +0000 Subject: [PATCH 4/5] [doc]: Adding to readme MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added GPT-OSS to the supported models table in README.md with its supported operations (RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy).
🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 55052dd71..8d8ed3858 100644 --- a/README.md +++ b/README.md @@ -264,6 +264,7 @@ loss.backward() | OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | +| GPT-OSS | `liger_kernel.transformers.apply_liger_kernel_to_gpt_oss` | RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy | | InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | | HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy | From f68437e7c75fadd43d9c6946c814c7dced678e06 Mon Sep 17 00:00:00 2001 From: yeshsurya Date: Sun, 23 Nov 2025 13:20:26 +0530 Subject: [PATCH 5/5] [update]: unpack result into tuple --- src/liger_kernel/transformers/model/gpt_oss.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/liger_kernel/transformers/model/gpt_oss.py b/src/liger_kernel/transformers/model/gpt_oss.py index de258e0b9..9a277da52 100644 --- a/src/liger_kernel/transformers/model/gpt_oss.py +++ b/src/liger_kernel/transformers/model/gpt_oss.py @@ -4,11 +4,12 @@ import torch -from transformers.modeling_outputs import MoeCausalLMOutputWithPast from transformers.modeling_outputs import MoeModelOutputWithPast from 
transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss +from liger_kernel.transformers.model.loss_utils import unpack_cross_entropy_result +from liger_kernel.transformers.model.output_classes import LigerMoeCausalLMOutputWithPast def lce_forward( @@ -27,7 +28,7 @@ def lce_forward( logits_to_keep: Union[int, torch.Tensor] = 0, skip_logits: Optional[bool] = None, **kwargs, -) -> MoeCausalLMOutputWithPast: +) -> LigerMoeCausalLMOutputWithPast: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., @@ -92,12 +93,13 @@ def lce_forward( shift_labels = kwargs.pop("shift_labels", None) logits = None loss = None + token_accuracy = None if skip_logits is None: skip_logits = self.training and (labels is not None or shift_labels is not None) if skip_logits: - loss = LigerForCausalLMLoss( + result = LigerForCausalLMLoss( hidden_states=kept_hidden_states, lm_head_weight=self.lm_head.weight, labels=labels, @@ -105,6 +107,7 @@ def lce_forward( hidden_size=self.config.hidden_size, **kwargs, ) + loss, _, token_accuracy = unpack_cross_entropy_result(result) else: # if in inference model materialize logits logits = self.lm_head(kept_hidden_states) if labels is not None or shift_labels is not None: @@ -127,7 +130,7 @@ def lce_forward( if labels is not None: loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device - return MoeCausalLMOutputWithPast( + return LigerMoeCausalLMOutputWithPast( loss=loss, aux_loss=aux_loss, logits=logits, @@ -135,4 +138,5 @@ def lce_forward( hidden_states=outputs.hidden_states, attentions=outputs.attentions, router_logits=outputs.router_logits, + token_accuracy=token_accuracy, )