Skip to content
Draft
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gpt_oss # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_llama4 # noqa: F401
Expand Down Expand Up @@ -89,6 +90,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
Expand Down Expand Up @@ -148,6 +150,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_gpt_oss",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_llama",
"apply_liger_kernel_to_llava",
Expand Down
113 changes: 113 additions & 0 deletions src/liger_kernel/transformers/model/gpt_oss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Optional
from typing import Union

import torch

from transformers.cache_utils import Cache
from transformers.modeling_outputs import MoeCausalLMOutputWithPast
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.models.mixtral.modeling_mixtral import load_balancing_loss_func
from transformers.processing_utils import Unpack
from transformers.utils import TransformersKwargs

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss


def lce_forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_router_logits: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: Unpack[TransformersKwargs],
) -> MoeCausalLMOutputWithPast:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

Example:

```python
>>> from transformers import AutoTokenizer, GptOssForCausalLM

>>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1")
>>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1")

>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")

>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""

output_router_logits = (
output_router_logits if output_router_logits is not None else self.config.output_router_logits
)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs: MoeModelOutputWithPast = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_router_logits=output_router_logits,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None

# if in training mode, do not materialize logits
if self.training and (labels is not None or shift_labels is not None):
loss = LigerForCausalLMLoss(
hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)
else: # if in inference model materialize logits
logits = self.lm_head(kept_hidden_states)
if labels is not None:
loss = self.loss_function(logits, labels, self.vocab_size, **kwargs)

aux_loss = None
if output_router_logits:
aux_loss = load_balancing_loss_func(
outputs.router_logits,
self.num_experts,
self.num_experts_per_tok,
attention_mask,
)
if labels is not None:
loss += self.router_aux_loss_coef * aux_loss.to(loss.device) # make sure to reside in the same device

return MoeCausalLMOutputWithPast(
loss=loss,
aux_loss=aux_loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
router_logits=outputs.router_logits,
)
53 changes: 53 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -1839,13 +1839,66 @@ def apply_liger_kernel_to_glm4(
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)


def apply_liger_kernel_to_gpt_oss(
rope: bool = True,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if swiglu can't be implemented now, let's set to False by default and raise NotImplementedError if set to True

model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GPT OSS models.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)

from transformers.models.gpt_oss import modeling_gpt_oss
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel

from liger_kernel.transformers.model.gpt_oss import lce_forward as gpt_oss_lce_forward

if rope:
modeling_gpt_oss.apply_rotary_pos_emb = liger_rotary_pos_emb

if rms_norm:
modeling_gpt_oss.GptOssRMSNorm = LigerRMSNorm

if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy

if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(gpt_oss_lce_forward, model)
else:
modeling_gpt_oss.GptOssForCausalLM.forward = gpt_oss_lce_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

# get the base model from the model instance
base_model: GptOssModel = getattr(model, model.base_model_prefix, model)

if rms_norm:
_patch_rms_norm_module(base_model.norm)
for decoder_layer in base_model.layers:
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


# Model type corresponds to the keys defined in transformers/models/auto/modeling_auto.py
MODEL_TYPE_TO_APPLY_LIGER_FN = {
"gemma": apply_liger_kernel_to_gemma,
"gemma2": apply_liger_kernel_to_gemma2,
"gemma3_text": apply_liger_kernel_to_gemma3_text,
"gemma3": apply_liger_kernel_to_gemma3,
"glm4": apply_liger_kernel_to_glm4,
"gpt_oss": apply_liger_kernel_to_gpt_oss,
"llama": apply_liger_kernel_to_llama,
"llama4_text": apply_liger_kernel_to_llama4,
"llama4": apply_liger_kernel_to_llama4,
Expand Down
71 changes: 68 additions & 3 deletions test/convergence/bf16/test_mini_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_gemma2
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text
from liger_kernel.transformers import apply_liger_kernel_to_glm4
from liger_kernel.transformers import apply_liger_kernel_to_gpt_oss
from liger_kernel.transformers import apply_liger_kernel_to_granite
from liger_kernel.transformers import apply_liger_kernel_to_llama
from liger_kernel.transformers import apply_liger_kernel_to_llama4
Expand All @@ -46,6 +47,7 @@
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_gemma3_text
from test.utils import revert_liger_kernel_to_glm4
from test.utils import revert_liger_kernel_to_gpt_oss
from test.utils import revert_liger_kernel_to_granite
from test.utils import revert_liger_kernel_to_llama
from test.utils import revert_liger_kernel_to_llama4
Expand Down Expand Up @@ -168,6 +170,15 @@
except ImportError:
SMOLLM3_AVAILABLE = False

try:
# GPT OSS is only available in transformers>=4.55.0
from transformers.models.gpt_oss.configuration_gpt_oss import GptOssConfig
from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM

GPT_OSS_AVAILABLE = True
except ImportError:
GPT_OSS_AVAILABLE = False

from liger_kernel.utils import infer_device

device = infer_device()
Expand Down Expand Up @@ -856,6 +867,43 @@
),
)

if GPT_OSS_AVAILABLE:
MINI_MODEL_SETUPS["mini_gpt_oss"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_gpt_oss,
liger_kernel_patch_revert_func=revert_liger_kernel_to_gpt_oss,
model_class=GptOssForCausalLM,
mini_model_config=GptOssConfig(
num_hidden_layers=4,
num_local_experts=32, # 128,
vocab_size=32000, # 201088,
hidden_size=896, # 2880,
intermediate_size=896, # 2880,
head_dim=64,
num_attention_heads=8, # 16,
num_key_value_heads=2, # 4,
sliding_window=128,
rope_theta=150000.0,
tie_word_embeddings=False,
hidden_act="silu",
initializer_range=0.02,
max_position_embeddings=32768, # 131072,
rms_norm_eps=1e-5,
rope_scaling=dict(
factor=32.0,
beta_fast=32.0,
beta_slow=1.0,
truncate=False,
rope_type="yarn",
),
attention_dropout=0.0,
num_experts_per_tok=2,
router_aux_loss_coef=0.9,
output_router_logits=False,
use_cache=True,
layer_types=None,
),
)


def create_model(model_name="mini_llama4"):
"""
Expand Down Expand Up @@ -1281,9 +1329,7 @@ def run_mini_model(
# 1e-2,
# 1e-2,
# 1e-2,
# marks=pytest.mark.skipif(
# not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
# ),
# marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
# ),
pytest.param(
"mini_gemma3_text",
Expand All @@ -1304,6 +1350,25 @@ def run_mini_model(
),
],
),
pytest.param(
"mini_gpt_oss",
32,
1e-5,
torch.bfloat16,
1e-2,
5e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not GPT_OSS_AVAILABLE,
reason="GPT OSS not available in this version of transformers",
),
],
),
],
)
def test_mini_model(
Expand Down
Loading
Loading