Commit 1ae79d2

Add AFMoE model support
1 parent 2b8068c commit 1ae79d2

File tree: 11 files changed, +1955 -0 lines changed


docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -382,6 +382,8 @@
       title: Main Classes
   - sections:
     - sections:
+      - local: model_doc/afmoe
+        title: AFMoE
       - local: model_doc/albert
         title: ALBERT
       - local: model_doc/apertus

docs/source/en/model_doc/afmoe.md

Lines changed: 129 additions & 0 deletions
@@ -0,0 +1,129 @@
<!--Copyright 2025 Arcee AI and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
    <div class="flex flex-wrap space-x-1">
        <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
        <img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
        <img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
    </div>
</div>

# AFMoE

AFMoE (Arcee Foundational Mixture of Experts) is a decoder-only transformer model that extends the Llama architecture with a sparse Mixture of Experts (MoE) approach. The model combines token-choice routing with shared experts and employs several architectural innovations for efficient inference and improved performance.

## Key Architecture Features

AFMoE introduces several key modifications to the standard transformer architecture:

- **Mixture of Experts with Shared Experts**: Combines routed experts (activated per-token via learned routing) with always-active shared experts for stable base computation
- **Token-Choice Routing**: Uses sigmoid- or softmax-based routing with normalization and scaling for expert selection
- **Q/K Normalization and Gating**: Applies RMSNorm to query and key projections and uses sigmoid gating on attention outputs for improved stability
- **Hybrid Attention Patterns**: Alternates between sliding window attention and full attention across layers for efficiency with long contexts
- **Dual Normalization**: Uses pre- and post-normalization around both attention and MLP blocks for training stability
- **Configurable Dense Layers**: Allows initial layers to use dense MLPs before transitioning to sparse MoE layers

The model supports extended context lengths with RoPE embeddings and works with all standard Transformers features, including Flash Attention 2, SDPA, gradient checkpointing, and quantization.
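
These architectural choices are exposed as fields on [`AfmoeConfig`] (added in this commit). The snippet below is a minimal, illustrative sketch of a small configuration; the values are arbitrary and do not correspond to a released checkpoint.

```py
from transformers import AfmoeConfig

# Illustrative values only -- not a released checkpoint.
config = AfmoeConfig(
    num_hidden_layers=8,
    num_dense_layers=1,       # layer 0 uses a dense MLP, the remaining layers use MoE
    num_experts=16,           # routed experts per MoE layer
    num_experts_per_tok=4,    # top-k routed experts activated per token
    num_shared_experts=2,     # shared experts that run on every token
    score_func="sigmoid",     # routing score function ("sigmoid" or "softmax")
)
```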

> [!TIP]
> AFMoE is particularly well-suited for scenarios requiring efficient scaling through sparsity while maintaining strong performance. The shared experts provide a stable computation baseline while routed experts enable model capacity scaling.

The example below demonstrates how to generate text with AFMoE using [`Pipeline`] or [`AutoModel`].

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="arcee-ai/Trinity-Mini",
    torch_dtype=torch.bfloat16,
    device=0
)

output = pipeline("The key innovation in mixture of experts is")
print(output[0]["generated_text"])
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoTokenizer, AfmoeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/Trinity-Mini")
model = AfmoeForCausalLM.from_pretrained(
    "arcee-ai/Trinity-Mini",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

inputs = tokenizer("The key innovation in mixture of experts is", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>

## Model Architecture Details

### Expert Routing

AFMoE uses token-choice routing where each token independently selects top-k experts based on router logits. The routing mechanism includes:

- Configurable scoring function (sigmoid or softmax)
- Optional route normalization for balanced expert utilization
- Route scaling to control expert contribution strength
- Bias correction for expert selection
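
The sketch below is a simplified, self-contained version of this routing scheme (sigmoid scoring, top-k selection, optional weight normalization, and route scaling). It is meant for intuition only; tensor names and the exact order of operations in the modeling code may differ.

```py
import torch

def route_tokens(router_logits, top_k=6, score_func="sigmoid", route_norm=True, route_scale=1.0):
    """Simplified token-choice routing. router_logits has shape (num_tokens, num_experts)."""
    if score_func == "sigmoid":
        scores = router_logits.sigmoid()
    else:
        scores = router_logits.softmax(dim=-1)

    # Each token independently picks its top-k experts.
    topk_scores, topk_indices = scores.topk(top_k, dim=-1)

    # Optionally renormalize the selected weights so they sum to 1 per token.
    if route_norm:
        topk_scores = topk_scores / topk_scores.sum(dim=-1, keepdim=True)

    # Scale the contribution of the routed experts.
    return topk_scores * route_scale, topk_indices

weights, experts = route_tokens(torch.randn(4, 64), top_k=6)
print(weights.shape, experts.shape)  # torch.Size([4, 6]) torch.Size([4, 6])
```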
### Shared Experts

Unlike standard MoE models, AFMoE includes shared experts that are always activated for every token, providing:

- A stable computation baseline across all tokens
- Reduced variance in model outputs
- Better handling of out-of-distribution inputs
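
Conceptually, a MoE block adds the always-active shared expert output to the weighted outputs of the routed experts. In the minimal sketch below the experts are stand-in `nn.Linear` layers and the routing weights are random; it only illustrates how the two paths combine.

```py
import torch
from torch import nn

hidden_size, num_experts, top_k = 64, 8, 2
routed_experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(num_experts))
shared_experts = nn.ModuleList(nn.Linear(hidden_size, hidden_size) for _ in range(2))

tokens = torch.randn(5, hidden_size)              # 5 tokens
weights = torch.rand(5, top_k)                    # routing weights from the router
indices = torch.randint(num_experts, (5, top_k))  # top-k expert ids per token

with torch.no_grad():
    # Shared experts run on every token and provide the stable baseline.
    output = sum(expert(tokens) for expert in shared_experts)

    # Routed experts contribute only where selected, scaled by their routing weight.
    for t in range(tokens.shape[0]):
        for k in range(top_k):
            output[t] += weights[t, k] * routed_experts[int(indices[t, k])](tokens[t])
```
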
### Attention Mechanism

The hybrid attention pattern alternates between:

- **Sliding Window Attention**: For efficiency on long sequences, with configurable window size
- **Full Attention**: Applied every N layers (configurable via `global_attn_every_n_layers`) for global context

All attention layers include Q/K normalization and output gating for improved training dynamics.
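
The per-layer pattern is stored in `layer_types` on [`AfmoeConfig`] and, when not given explicitly, is derived from `global_attn_every_n_layers`. A short illustrative check with a small (hypothetical) layer count:

```py
from transformers import AfmoeConfig

config = AfmoeConfig(num_hidden_layers=8, global_attn_every_n_layers=4)
print(config.layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']

# An explicit pattern can also be passed directly:
config = AfmoeConfig(num_hidden_layers=4, layer_types=["full_attention"] * 4)
```
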
## AfmoeConfig

[[autodoc]] AfmoeConfig

## AfmoeModel

[[autodoc]] AfmoeModel
    - forward

## AfmoeForCausalLM

[[autodoc]] AfmoeForCausalLM
    - forward

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_afmoe import *
    from .modeling_afmoe import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)

Lines changed: 228 additions & 0 deletions
@@ -0,0 +1,228 @@
# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AFMoE model configuration"""

from typing import Optional

from ...configuration_utils import PreTrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation, standardize_rope_params
from ...utils import logging


logger = logging.get_logger(__name__)


class AfmoeConfig(PreTrainedConfig):
    r"""
    This is the configuration class to store the configuration of an [`AfmoeModel`]. It is used to instantiate an
    AFMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
    with the defaults will yield a similar configuration to that of [arcee-ai/Trinity-Mini](https://huggingface.co/arcee-ai/Trinity-Mini).

    AFMoE is an Adaptive Feedforward MoE (Mixture of Experts) model with token-choice routing, shared experts, and a
    hybrid attention mechanism combining sliding window and full attention patterns.

    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PreTrainedConfig`] for more information.

    Args:
        vocab_size (`int`, *optional*, defaults to 200192):
            Vocabulary size of the AFMoE model. Defines the number of different tokens that can be represented by the
            `input_ids` passed when calling [`AfmoeModel`].
        hidden_size (`int`, *optional*, defaults to 2048):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 6144):
            Dimension of the dense MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 1408):
            Intermediate size of the routed expert MLPs.
        num_hidden_layers (`int`, *optional*, defaults to 32):
            Number of hidden layers in the Transformer decoder.
        num_dense_layers (`int`, *optional*, defaults to 1):
            Number of initial dense layers before MoE layers begin. Layers with index < num_dense_layers will use
            standard dense MLPs instead of MoE.
        num_attention_heads (`int`, *optional*, defaults to 16):
            Number of attention heads for each attention layer in the Transformer decoder.
        num_key_value_heads (`int`, *optional*):
            This is the number of key_value heads that should be used to implement Grouped Query Attention. If
            `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
            `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
            converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
            by meanpooling all the original heads within that group. For more details, check out [this
            paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
            `num_attention_heads`.
        head_dim (`int`, *optional*, defaults to 128):
            The dimension of each attention head.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function (function or string) in the MLP blocks.
        max_position_embeddings (`int`, *optional*, defaults to 16384):
            The maximum sequence length that this model might ever be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-05):
            The epsilon used by the RMS normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether the model's input and output word embeddings should be tied.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`dict`, *optional*):
            Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
            a value for `rope_type` and optionally parameters used for scaling in case you want to use RoPE
            with longer `max_position_embeddings`.
        num_experts (`int`, *optional*, defaults to 64):
            Number of routed experts in MoE layers.
        num_experts_per_tok (`int`, *optional*, defaults to 6):
            Number of experts to route each token to. This is the top-k value for the token-choice routing.
        num_shared_experts (`int`, *optional*, defaults to 2):
            Number of shared experts that are always activated for all tokens.
        score_func (`str`, *optional*, defaults to `"sigmoid"`):
            The scoring function for routing decisions. Can be either "sigmoid" or "softmax".
        route_norm (`bool`, *optional*, defaults to `True`):
            Whether to normalize routing weights when using sigmoid scoring.
        route_scale (`float`, *optional*, defaults to 1.0):
            Scaling factor applied to routing weights.
        global_attn_every_n_layers (`int`, *optional*, defaults to 4):
            The frequency of full attention layers. Every Nth layer will use full attention, while others use sliding
            window attention.
        sliding_window (`int`, *optional*, defaults to 1024):
            Sliding window size for local attention layers.
        mup_enabled (`bool`, *optional*, defaults to `False`):
            Whether to enable muP (Maximal Update Parametrization) scaling for training stability.
        layer_types (`list[str]`, *optional*):
            A list that explicitly maps each layer index with its attention type. Each element should be either
            "sliding_attention" or "full_attention". If not provided, it will be automatically generated based on
            `global_attn_every_n_layers`.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> from transformers import AfmoeModel, AfmoeConfig

    >>> # Initializing an AFMoE configuration
    >>> configuration = AfmoeConfig()

    >>> # Initializing a model from the Trinity-Mini style configuration
    >>> model = AfmoeModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```
    """

    model_type = "afmoe"
    keys_to_ignore_at_inference = ["past_key_values"]

    # Default pipeline parallel plan for base model
    base_model_pp_plan = {
        "embed_tokens": (["input_ids"], ["inputs_embeds"]),
        "layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
        "norm": (["hidden_states"], ["hidden_states"]),
    }

    def __init__(
        self,
        vocab_size: Optional[int] = 200192,
        hidden_size: Optional[int] = 2048,
        intermediate_size: Optional[int] = 6144,
        moe_intermediate_size: Optional[int] = 1408,
        num_hidden_layers: Optional[int] = 32,
        num_dense_layers: Optional[int] = 1,
        num_attention_heads: Optional[int] = 16,
        num_key_value_heads: Optional[int] = None,
        head_dim: Optional[int] = 128,
        hidden_act: Optional[str] = "silu",
        max_position_embeddings: Optional[int] = 16384,
        initializer_range: Optional[float] = 0.02,
        rms_norm_eps: Optional[float] = 1e-5,
        use_cache: Optional[bool] = True,
        tie_word_embeddings: Optional[bool] = False,
        rope_theta: Optional[float] = 10000.0,
        rope_scaling: Optional[dict] = None,
        num_experts: Optional[int] = 64,
        num_experts_per_tok: Optional[int] = 6,
        num_shared_experts: Optional[int] = 2,
        score_func: Optional[str] = "sigmoid",
        route_norm: Optional[bool] = True,
        route_scale: Optional[float] = 1.0,
        global_attn_every_n_layers: Optional[int] = 4,
        sliding_window: Optional[int] = 1024,
        mup_enabled: Optional[bool] = False,
        layer_types: Optional[list] = None,
        attention_dropout: Optional[float] = 0.0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_dense_layers = num_dense_layers
        self.num_attention_heads = num_attention_heads
        self.head_dim = head_dim
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling

        # MoE specific
        self.moe_intermediate_size = moe_intermediate_size
        self.num_experts_per_tok = num_experts_per_tok
        self.num_experts = num_experts
        self.num_shared_experts = num_shared_experts
        self.score_func = score_func
        self.route_norm = route_norm
        self.route_scale = route_scale

        # Attention specific
        self.attention_dropout = attention_dropout
        self.global_attn_every_n_layers = global_attn_every_n_layers
        self.sliding_window = sliding_window
        self.layer_types = layer_types
        if self.layer_types is None:
            # Every `global_attn_every_n_layers`-th layer uses full attention; all others use sliding-window attention.
            self.layer_types = [
                "sliding_attention"
                if bool((i + 1) % global_attn_every_n_layers)
                else "full_attention"
                for i in range(self.num_hidden_layers)
            ]
        layer_type_validation(self.layer_types)

        # muP specific
        self.mup_enabled = mup_enabled

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads

        # Setup and validate rope configs
        self.rope_parameters = rope_scaling
        standardize_rope_params(self, rope_theta=rope_theta)
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


__all__ = ["AfmoeConfig"]
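
Since `AfmoeConfig` inherits from `PreTrainedConfig`, the usual save/load round-trip applies to the fields defined above. A brief illustrative sketch (the local directory name is arbitrary):

```py
from transformers import AfmoeConfig

config = AfmoeConfig(num_hidden_layers=8, num_experts=16, num_experts_per_tok=4)
config.save_pretrained("./afmoe-config")   # writes config.json with model_type "afmoe"
reloaded = AfmoeConfig.from_pretrained("./afmoe-config")
assert reloaded.num_experts == 16
```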

0 commit comments
