
Conversation

@DonghakPark (Member)

Dependency of the PR

  • None

Commits to be reviewed in this PR

[CausalLM] Implement ERNIE's GLM Style RoPE

Implement GLM Style RoPE at MHA CORE

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>
[CausalLM] Add ernie to main & meson build

Add ernie model & Layer to main, meson build

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>
[CausalLM] Implement Ernie MoE Layer

Implement Ernie MoE Layer
- shared expert accumulation
- static bias add

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>
[CausalLM] Add causallm common properties <num_shared_experts, moe_norm_min>

Add causallm common properties
- num_shared_experts
- moe_norm_min

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>
[Application][CausalLM] Implement Ernie 4.5 MoE Model

Implement Ernie 4.5 MoE Model
- Ernie's first layer is dense
- Ernie has shared experts at each MoE layer

**Self evaluation:**
1. Build test:   [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>

Summary

This PR implements and integrates the Ernie 4.5 MoE (Mixture of Experts) model into the CausalLM application.
The changes include the implementation of the model structure, the MoE layer, GLM-style RoPE, and build system integration.

Ernie 4.5: Key Differences from Qwen

1. Qwen applies RMSNorm to Q and K before RoPE, but Ernie does not (see the sketch below).
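
A minimal, hypothetical sketch of this difference in the attention path: q_norm / k_norm are placeholder modules standing in for Qwen's per-head RMSNorm, and apply_rotary_pos_emb refers to the GLM-style function quoted under point 2 below.

import torch.nn as nn

def qwen_style_qk(q, k, cos, sin, q_norm: nn.Module, k_norm: nn.Module):
    # Qwen: normalize Q and K first, then apply RoPE
    return apply_rotary_pos_emb(q_norm(q), k_norm(k), cos, sin)

def ernie_style_qk(q, k, cos, sin):
    # Ernie 4.5: apply RoPE directly to the projected Q and K (no QK-Norm)
    return apply_rotary_pos_emb(q, k, cos, sin)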

2. Ernie applies GLM-style RoPE:

import torch


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    return torch.stack((-x2, x1), dim=-1).flatten(-2)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    # glm rope style (with full dim) and full precision
    original_dtype = q.dtype

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    # Interleave them instead of usual shape
    cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
    sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)

    q_embed = (q.float() * cos) + (rotate_half(q).float() * sin)
    k_embed = (k.float() * cos) + (rotate_half(k).float() * sin)

    return q_embed.to(original_dtype), k_embed.to(original_dtype)
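
For contrast, a minimal sketch of the more common half-split rotation (as in the Hugging Face LLaMA/Qwen code path): the GLM-style rotate_half above instead pairs adjacent even/odd dimensions and repeats the first half of cos/sin via repeat_interleave.

import torch

def rotate_half_split(x):
    # Common (non-GLM) variant: split the head dim in half and swap with a sign flip
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)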

3. Ernie uses 2 shared experts + 6 top-k routed experts (in the 21B-A3B model); see the sketch below.
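
A minimal, hypothetical sketch of how the shared experts and the top-k routed experts could be combined. The expert modules, the router, and the use of moe_norm_min to clamp the normalization denominator are illustrative assumptions, not the actual nntrainer layer API; the bias correction from point 4 is omitted here.

import torch

def ernie_moe_forward(hidden, shared_experts, routed_experts, router,
                      top_k=6, moe_norm_min=1e-12):
    # hidden: [tokens, dim]; shared_experts / routed_experts: lists of MLP modules;
    # router: nn.Linear(dim, num_routed_experts)

    # Shared experts always run; their outputs are accumulated unconditionally
    out = sum(expert(hidden) for expert in shared_experts)

    # Router scores and top-k selection
    probs = torch.softmax(router(hidden.float()), dim=-1)
    top_vals, top_idx = torch.topk(probs, top_k, dim=-1)

    # Renormalize the selected weights; moe_norm_min assumed to bound the denominator
    top_vals = top_vals / top_vals.sum(dim=-1, keepdim=True).clamp(min=moe_norm_min)

    # Scatter the normalized weights back to a dense [tokens, num_experts] map
    dense_w = torch.zeros_like(probs).scatter_(-1, top_idx, top_vals)

    # Accumulate the weighted outputs of the selected routed experts
    for e, expert in enumerate(routed_experts):
        mask = dense_w[:, e] > 0
        if mask.any():
            out[mask] = out[mask] + dense_w[mask, e].unsqueeze(-1) * expert(hidden[mask])
    return out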

4. Ernie has an e_score_correction_bias and adds this bias after the router softmax:

# Router logits are computed in full precision
router_logits = F.linear(hidden_states.float(), self.weight)
router_logits = F.softmax(router_logits, dim=1, dtype=torch.float)
# moe_statics adds the static e_score_correction_bias before the top-k selection
router_top_value, router_indices = torch.topk(self.moe_statics(router_logits), self.top_k, dim=-1)
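
A standalone, hypothetical version of the snippet above, with the e_score_correction_bias made explicit instead of going through the moe_statics module; the function name and signature are illustrative, not the actual model code.

import torch
import torch.nn.functional as F

def route_with_correction_bias(hidden_states, gate_weight, e_score_correction_bias, top_k=6):
    # Router logits in full precision: [tokens, num_experts]
    router_logits = F.linear(hidden_states.float(), gate_weight)
    # Softmax first ...
    probs = F.softmax(router_logits, dim=1, dtype=torch.float)
    # ... then add the static correction bias before the top-k selection
    corrected = probs + e_score_correction_bias
    return torch.topk(corrected, top_k, dim=-1)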

Signed-off-by: Donghak PARK [email protected]

Avoid race condition on expert eviction

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghak PARK <[email protected]>