Skip to content
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_gemma3_text # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_glm4v_moe # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_granite # noqa: F401
Expand Down Expand Up @@ -109,6 +110,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4_moe",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_granite",
"apply_liger_kernel_to_internvl",
Expand Down Expand Up @@ -185,6 +187,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_gemma3",
"apply_liger_kernel_to_gemma3_text",
"apply_liger_kernel_to_glm4",
"apply_liger_kernel_to_glm4_moe",
"apply_liger_kernel_to_glm4v",
"apply_liger_kernel_to_glm4v_moe",
"apply_liger_kernel_to_granite",
Expand Down
154 changes: 154 additions & 0 deletions src/liger_kernel/transformers/model/glm4_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.utils.deprecation import deprecate_kwarg

from liger_kernel.transformers.model.loss_utils import LigerForCausalLMLoss
from liger_kernel.transformers.model.output_classes import LigerCausalLMOutputWithPast


@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
def lce_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0,
skip_logits: Optional[bool] = None,
**kwargs,
) -> Union[Tuple, LigerCausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
The temporal, height and width of feature shape of each image in LLM.
video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
The temporal, height and width of feature shape of each video in LLM.
rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
The rope index difference between sequence length and multimodal rope.


logits_to_keep (`int` or `torch.Tensor`, *optional*):
If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
This is useful when using packed tensor format (single dimension for batch and sequence length).

Example:

```python
>>> from transformers import AutoProcessor, Glm4MoeForCausalLM
>>> import torch

>>> MODEL_PATH = "meta-glm4_moe/Glm4Moe-2-7b-hf"
>>> messages = [
{
"role": "user",
"content": [
{
"type": "image",
"url": "https://upload.wikimedia.org/wikipedia/commons/f/fa/Grayscale_8bits_palette_sample_image.png"
},
{
"type": "text",
"text": "describe this image"
}
],
}
]
>>> processor = AutoProcessor.from_pretrained(MODEL_PATH)
>>> model = Glm4MoeForCausalLM.from_pretrained(
pretrained_model_name_or_path=MODEL_PATH,
dtype="auto",
device_map="auto",
)
>>> inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt"
).to(model.device)
>>> inputs.pop("token_type_ids", None)
>>> generated_ids = model.generate(**inputs, max_new_tokens=8192)
>>> output_text = processor.decode(generated_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=False)
```
"""

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)

hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
kept_hidden_states = hidden_states[:, slice_indices, :]

shift_labels = kwargs.pop("shift_labels", None)
logits = None
loss = None
token_accuracy = None

if skip_logits and labels is None and shift_labels is None:
raise ValueError("skip_logits is True, but labels and shift_labels are None")

if skip_logits is None:
# By default, if in training mode, don't materialize logits
skip_logits = self.training and (labels is not None or shift_labels is not None)

if skip_logits:
loss = LigerForCausalLMLoss(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kindly have a look at the other model examples and adapt to new API that returns the metric

hidden_states=kept_hidden_states,
lm_head_weight=self.lm_head.weight,
labels=labels,
shift_labels=shift_labels,
hidden_size=self.config.hidden_size,
**kwargs,
)

else:
logits = self.lm_head(kept_hidden_states)
if labels is not None or shift_labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
shift_labels=shift_labels,
vocab_size=self.config.vocab_size,
**kwargs,
)

return LigerCausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
token_accuracy=token_accuracy,
)
76 changes: 76 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,6 +2065,81 @@ def apply_liger_kernel_to_glm4(
_patch_rms_norm_module(decoder_layer.post_mlp_layernorm, in_place=False)


def apply_liger_kernel_to_glm4_moe(
rope: bool = False,
cross_entropy: bool = False,
fused_linear_cross_entropy: bool = True,
rms_norm: bool = True,
swiglu: bool = True,
model: PreTrainedModel = None,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace GLM-4MOE models.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is False.
cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
fused_linear_cross_entropy (bool):
Whether to apply Liger's fused linear cross entropy loss. Default is True.
`cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU Glm4MLP. Default is True.
model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
loaded. Default is None.
"""
assert not (cross_entropy and fused_linear_cross_entropy), (
"cross_entropy and fused_linear_cross_entropy cannot both be True."
)
from transformers.models.glm4_moe import modeling_glm4_moe
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeModel
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeMoE

from liger_kernel.transformers.model.glm4_moe import lce_forward as glm4_moe_lce_forward
from liger_kernel.transformers.rms_norm import LigerRMSNormForGlm4

if rope:
modeling_glm4_moe.apply_rotary_pos_emb = liger_rotary_pos_emb_with_cast
if rms_norm:
modeling_glm4_moe.Glm4MoeRMSNorm = LigerRMSNormForGlm4
if cross_entropy:
from transformers.loss.loss_utils import nn

nn.functional.cross_entropy = liger_cross_entropy
if fused_linear_cross_entropy:
if model is not None:
model.forward = MethodType(glm4_moe_lce_forward, model)
else:
modeling_glm4_moe.Glm4MoeForCausalLM.forward = glm4_moe_lce_forward

if model is not None:
# The model instance already exists, so we need to additionally patch the
# instance variables that reference already-instantiated modules

# get the base model from the model instance
base_model: Glm4MoeModel = getattr(model, model.base_model_prefix, model)

if rms_norm:
_patch_rms_norm_module(base_model.norm)

for decoder_layer in base_model.layers:
if swiglu:
_patch_swiglu_module(decoder_layer.mlp, LigerSwiGLUMLP)
if rms_norm:
_patch_rms_norm_module(decoder_layer.input_layernorm)
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)
# patch MOE layers
if isinstance(decoder_layer.mlp, Glm4MoeMoE):
experts = decoder_layer.mlp.experts
if experts is not None:
for expert in experts:
_patch_swiglu_module(expert, LigerSwiGLUMLP)

shared_experts = decoder_layer.mlp.shared_experts
if shared_experts is not None:
_patch_swiglu_module(shared_experts, LigerSwiGLUMLP)


def apply_liger_kernel_to_glm4v(
rope: bool = False,
cross_entropy: bool = False,
Expand Down Expand Up @@ -2750,6 +2825,7 @@ def apply_liger_kernel_to_hunyuan_v1_moe(
"gemma3_text": apply_liger_kernel_to_gemma3_text,
"gemma3": apply_liger_kernel_to_gemma3,
"glm4": apply_liger_kernel_to_glm4,
"glm4_moe": apply_liger_kernel_to_glm4_moe,
"glm4v": apply_liger_kernel_to_glm4v,
"glm4v_moe": apply_liger_kernel_to_glm4v_moe,
"internvl": apply_liger_kernel_to_internvl,
Expand Down
60 changes: 60 additions & 0 deletions test/convergence/bf16/test_mini_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_gemma2
from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text
from liger_kernel.transformers import apply_liger_kernel_to_glm4
from liger_kernel.transformers import apply_liger_kernel_to_glm4_moe
from liger_kernel.transformers import apply_liger_kernel_to_glm4v
from liger_kernel.transformers import apply_liger_kernel_to_glm4v_moe
from liger_kernel.transformers import apply_liger_kernel_to_granite
Expand Down Expand Up @@ -62,6 +63,7 @@
from test.utils import revert_liger_kernel_to_gemma2
from test.utils import revert_liger_kernel_to_gemma3_text
from test.utils import revert_liger_kernel_to_glm4
from test.utils import revert_liger_kernel_to_glm4_moe
from test.utils import revert_liger_kernel_to_glm4v
from test.utils import revert_liger_kernel_to_glm4v_moe
from test.utils import revert_liger_kernel_to_granite
Expand Down Expand Up @@ -214,6 +216,14 @@
except ImportError:
GLM4_AVAILABLE = False

try:
from transformers.models.glm4_moe.configuration_glm4_moe import Glm4MoeConfig
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeForCausalLM

GLM4_MOE_AVAILABLE = True
except ImportError:
GLM4_MOE_AVAILABLE = False

try:
# Glm4v is only available in transformers>=4.51.3
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig
Expand Down Expand Up @@ -1080,6 +1090,37 @@
),
)

if GLM4_MOE_AVAILABLE:
MINI_MODEL_SETUPS["mini_glm4_moe"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_glm4_moe,
liger_kernel_patch_revert_func=revert_liger_kernel_to_glm4_moe,
model_class=Glm4MoeForCausalLM,
mini_model_config=Glm4MoeConfig(
bos_token_id=1, # None
eos_token_id=2, # 151329, 151336, 151338
pad_token_id=2, # 151329
partial_rotary_factor=0.5,
cross_attention_layers=None,
dropout=0,
hidden_act="silu",
hidden_size=1024, # 6144
initializer_range=0.02,
intermediate_size=2048, # 14336
max_position_embeddings=4096, # 32768
num_attention_heads=8, # 48
num_hidden_layers=4, # 61
num_key_value_heads=2,
rms_norm_eps=1e-5,
rope_scaling=None,
rope_theta=500_000,
tie_word_embeddings=False,
use_cache=True,
vocab_size=32000, # 151552
attention_bias=True,
attn_implementation="sdpa", # default value, pytorch native attention
),
)

if GLM4V_AVAILABLE:
MINI_MODEL_SETUPS["mini_glm4v"] = MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_glm4v,
Expand Down Expand Up @@ -1829,6 +1870,25 @@ def run_mini_model(
),
],
),
pytest.param(
"mini_glm4_moe",
32,
1e-5,
torch.bfloat16,
1e-2,
1e-2,
1e-1,
1e-2,
1e-2,
1e-2,
marks=[
pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
pytest.mark.skipif(
not GLM4_MOE_AVAILABLE,
reason="Glm4_moe not available in this version of transformers",
),
],
),
pytest.param(
"mini_glm4v",
32,
Expand Down
Loading
Loading