Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ loss.backward()
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Olmo3 | `liger_kernel.transformers.apply_liger_kernel_to_olmo3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| GPT-OSS | `liger_kernel.transformers.apply_liger_kernel_to_gpt_oss` | RoPE, RMSNorm, CrossEntropyLoss, FusedLinearCrossEntropy |
| InternVL3 | `liger_kernel.transformers.apply_liger_kernel_to_internvl` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| HunyuanV1 | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_dense` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| HunyuanV1 MoE | `liger_kernel.transformers.apply_liger_kernel_to_hunyuan_v1_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
Expand Down
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_dense # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_hunyuan_v1_moe # noqa: F401
Expand Down Expand Up @@ -110,6 +111,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_internvl",
"apply_liger_kernel_to_llama",
Expand Down Expand Up @@ -187,6 +189,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_internvl",
"apply_liger_kernel_to_llama",
Expand Down
138 changes: 138 additions & 0 deletions src/liger_kernel/transformers/model/gpt_oss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
from typing import List
from typing import Optional
from typing import Union

import torch

from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).

Returns:

Example:

```python
>>> from transformers import AutoTokenizer, GptOssForCausalLM

>>> model = GptOssForCausalLM.from_pretrained("openai/gpt-oss-20b")
>>> tokenizer = AutoTokenizer.from_pretrained("openai/gpt-oss-20b")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
Comment on lines +45 to +62
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this does not seem to match the output of this function


output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)

output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None

if skip_logits is None:
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
loss = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference model materialize logits
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.vocab_size,
**kwargs,
)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
75 changes: 75 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from liger_kernel.transformers.model.gemma import lce_forward_deprecated as gemma_lce_forward_deprecated
from liger_kernel.transformers.model.gemma2 import lce_forward as gemma2_lce_forward
from liger_kernel.transformers.model.gemma2 import lce_forward_deprecated as gemma2_lce_forward_deprected
from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward
from liger_kernel.transformers.model.llama import lce_forward as llama_lce_forward
from liger_kernel.transformers.model.llama import lce_forward_deprecated as llama_lce_forward_deprecated
from liger_kernel.transformers.model.llava import lce_forward as llava_lce_forward
Expand Down Expand Up @@ -1459,6 +1460,79 @@ def apply_liger_kernel_to_qwen3_moe(
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_gpt_oss(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = False, # Set to False by default since GPT-OSS has custom expert implementation
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GPT-OSS models.
NOTE: GPT-OSS is supported in transformers >= 4.55.0
NOTE: SwiGLU patching is disabled by default for GPT-OSS as it uses a custom expert
implementation with clamping and MXFP4 quantization.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is False.
Note: GPT-OSS uses a custom expert implementation, so SwiGLU patching is disabled by default.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
if version.parse(transformers.__version__) < version.parse("4.55.0"):
logger.warning("GPT-OSS support requires transformers >= 4.55.0")
return

assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.gpt_oss import modeling_gpt_oss
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel

if rope:
modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb

if rms_norm:
modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm

if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(gpt_oss_lce_forward, model)
else:
modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward

# Note: SwiGLU patching is not implemented for GPT-OSS due to custom expert implementation
# with clamping (swiglu_limit=7.0) and MXFP4 quantization

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

# get the base model from the model instance
base_model: GptOssModel = getattr(model, model.base_model_prefix, model)

if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_qwen2_vl(
rope: bool = True,
cross_entropy: bool = False,
Expand Down Expand Up @@ -2752,6 +2826,7 @@ def apply_liger_kernel_to_hunyuan_v1_moe(
"glm4": apply_liger_kernel_to_glm4,
"glm4v": apply_liger_kernel_to_glm4v,
"glm4v_moe": apply_liger_kernel_to_glm4v_moe,
"gpt_oss": apply_liger_kernel_to_gpt_oss,
"internvl": apply_liger_kernel_to_internvl,
"llama": apply_liger_kernel_to_llama,
"llama4_text": apply_liger_kernel_to_llama4,
Expand Down
67 changes: 67 additions & 0 deletions test/convergence/bf16/test_mini_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_glm4
from liger_kernel.transformers import apply_liger_kernel_to_glm4v
from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
from liger_kernel.transformers import apply_liger_kernel_to_granite
from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_dense
from liger_kernel.transformers import apply_liger_kernel_to_hunyuan_v1_moe
Expand Down Expand Up @@ -64,6 +65,7 @@
from test.utils import revert_liger_kernel_to_glm4
from test.utils import revert_liger_kernel_to_glm4v
from test.utils import revert_liger_kernel_to_glm4v_moe
from test.utils import revert_liger_kernel_to_gpt_oss
from test.utils import revert_liger_kernel_to_granite
from test.utils import revert_liger_kernel_to_hunyuan_v1
from test.utils import revert_liger_kernel_to_hunyuan_v1_moe
Expand Down Expand Up @@ -267,6 +269,15 @@
except ImportError:
FALCONH1_AVAILABLE = False

try:
# GPT-OSS is only available in transformers>=4.55.0
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

GPT_OSS_AVAILABLE = True
except ImportError:
GPT_OSS_AVAILABLE = False

try:
# Qwen3Next is only available in transformers>=4.57.0
from transformers.models.qwen3_next.configuration_qwen3_next import Qwen3NextConfig
Expand Down Expand Up @@ -613,6 +624,43 @@
),
)

if GPT_OSS_AVAILABLE:
MINI_MODEL_SETUPS["mini_gpt_oss"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_gpt_oss,
liger_kernel_patch_revert_func=revert_liger_kernel_to_gpt_oss,
model_class=GptOssForCausalLM,
mini_model_config=GptOssConfig(
vocab_size=32000, # 201088
hidden_size=896,
intermediate_size=896, # Same as hidden_size for GPT-OSS
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=2,
head_dim=64,
hidden_act="silu",
max_position_embeddings=8192,
initializer_range=0.02,
rms_norm_eps=1e-5,
use_cache=True,
tie_word_embeddings=False,
rope_parameters={
"rope_type": "yarn",
"factor": 8.0,
"beta_fast": 32.0,
"beta_slow": 1.0,
"truncate": False,
"original_max_position_embeddings": 4096,
},
attention_dropout=0.0,
num_local_experts=8, # Reduced from 32 for mini model
num_experts_per_tok=2, # Reduced from 4 for mini model
router_aux_loss_coef=0.9,
output_router_logits=False,
sliding_window=128,
layer_types=["sliding_attention" if bool((i + 1) % 2) else "full_attention" for i in range(4)],
),
)

if GEMMA3_AVAILABLE:
MINI_MODEL_SETUPS["mini_gemma3_text"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_gemma3_text,
Expand Down Expand Up @@ -1664,6 +1712,25 @@ def run_mini_model(
),
],
),
pytest.param(
"mini_gpt_oss",
32,
1e-5,
torch.bfloat16,
5e-2,
5e-2,
1e-1,
1e-1,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not GPT_OSS_AVAILABLE,
reason="GPT-OSS not available in this version of transformers",
),
],
),
pytest.param(
"mini_qwen2_vl",
32,
Expand Down
Loading
Loading