2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -382,6 +382,8 @@
title: Main Classes
- sections:
- sections:
- local: model_doc/afmoe
title: AFMoE
- local: model_doc/albert
title: ALBERT
- local: model_doc/apertus
129 changes: 129 additions & 0 deletions docs/source/en/model_doc/afmoe.md
@@ -0,0 +1,129 @@
<!--Copyright 2025 Arcee AI and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

<div style="float: right;">
<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
<img alt="FlashAttention" src="https://img.shields.io/badge/%E2%9A%A1%EF%B8%8E%20FlashAttention-eae0c8?style=flat">
<img alt="SDPA" src="https://img.shields.io/badge/SDPA-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>
</div>

# AFMoE

AFMoE (Arcee Foundational Mixture of Experts) is a decoder-only transformer model that extends the Llama architecture with a sparse Mixture of Experts (MoE) approach. The model combines token-choice routing with shared experts and employs several architectural innovations for efficient inference and improved performance.

## Key Architecture Features

AFMoE introduces several key modifications to the standard transformer architecture:

- **Mixture of Experts with Shared Experts**: Combines routed experts (activated per-token via learned routing) with always-active shared experts for stable base computation
- **Token-Choice Routing**: Uses sigmoid or softmax-based routing with normalization and scaling for expert selection
- **Q/K Normalization and Gating**: Applies RMSNorm to query and key projections and uses sigmoid gating on attention outputs for improved stability
- **Hybrid Attention Patterns**: Alternates between sliding window attention and full attention across layers for efficiency with long contexts
- **Dual Normalization**: Uses pre- and post-normalization around both attention and MLP blocks for training stability
- **Configurable Dense Layers**: Allows initial layers to use dense MLPs before transitioning to sparse MoE layers

The model supports extended context lengths via RoPE embeddings and works with standard Transformers features, including Flash Attention 2, SDPA, gradient checkpointing, and quantization.
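
For a concrete view of how these choices are exposed, the sketch below builds an [`AfmoeConfig`] with the relevant options spelled out. The parameter names come from the configuration class documented below; the values simply mirror the documented defaults and are meant as illustration, not as a tuning recommendation.

```py
from transformers import AfmoeConfig

# Illustrative configuration; these values mirror the documented defaults
config = AfmoeConfig(
    num_hidden_layers=32,
    num_dense_layers=1,             # initial layer(s) use a dense MLP before MoE layers begin
    num_experts=64,                 # routed experts per MoE layer
    num_experts_per_tok=6,          # top-k routed experts selected per token
    num_shared_experts=2,           # always-active shared experts
    score_func="sigmoid",           # routing score function ("sigmoid" or "softmax")
    global_attn_every_n_layers=4,   # every 4th layer uses full attention
    sliding_window=1024,            # window size for sliding-window attention layers
)
```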

> [!TIP]
> AFMoE is well-suited for workloads that need large model capacity at a modest per-token compute cost: only the top-k routed experts run for each token, while the always-active shared experts provide a stable computation baseline.

The example below demonstrates how to generate text with AFMoE using [`Pipeline`] or the [`AutoModel`] class.

<hfoptions id="usage">
<hfoption id="Pipeline">

```py
import torch
from transformers import pipeline

pipeline = pipeline(
    task="text-generation",
    model="arcee-ai/Trinity-Mini",
    torch_dtype=torch.bfloat16,
    device=0
)

output = pipeline("The key innovation in mixture of experts is")
print(output[0]["generated_text"])
```

</hfoption>
<hfoption id="AutoModel">

```py
import torch
from transformers import AutoTokenizer, AfmoeForCausalLM

tokenizer = AutoTokenizer.from_pretrained("arcee-ai/Trinity-Mini")
model = AfmoeForCausalLM.from_pretrained(
    "arcee-ai/Trinity-Mini",
    torch_dtype=torch.bfloat16,
    device_map="auto"
)

inputs = tokenizer("The key innovation in mixture of experts is", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=50)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

</hfoption>
</hfoptions>

## Model Architecture Details

### Expert Routing

AFMoE uses token-choice routing, where each token independently selects its top-k experts based on the router logits. The routing mechanism includes the following (a simplified sketch appears after the list):

- Configurable scoring function (sigmoid or softmax)
- Optional route normalization for balanced expert utilization
- Route scaling to control expert contribution strength
- Bias correction for expert selection
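
The modeling code itself is not shown on this page, so the snippet below is only a simplified sketch of token-choice routing under the options described above (`score_func`, `route_norm`, `route_scale`, and top-k selection). It is not the literal AFMoE implementation, and the bias correction for expert selection is omitted.

```py
import torch

def token_choice_route(router_logits, top_k=6, score_func="sigmoid", route_norm=True, route_scale=1.0):
    # router_logits: (num_tokens, num_experts)
    scores = router_logits.sigmoid() if score_func == "sigmoid" else router_logits.softmax(dim=-1)
    # each token independently picks its top-k experts
    top_scores, top_experts = scores.topk(top_k, dim=-1)
    if route_norm and score_func == "sigmoid":
        # renormalize the selected scores so they sum to 1 per token
        top_scores = top_scores / top_scores.sum(dim=-1, keepdim=True)
    return top_scores * route_scale, top_experts

weights, expert_indices = token_choice_route(torch.randn(4, 64))
```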

### Shared Experts

Unlike standard MoE models, AFMoE includes shared experts that are always active for every token (see the sketch after this list), providing:

- A stable computation baseline across all tokens
- Reduced variance in model outputs
- Better handling of out-of-distribution inputs
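
As a rough picture of how the two paths combine, the toy module below applies a shared expert to every token and adds the weighted outputs of the routed experts on top. It is an illustration only (plain linear layers, naive dispatch), not the AFMoE implementation.

```py
import torch
from torch import nn

class ToyMoEWithSharedExpert(nn.Module):
    """Toy illustration: one shared expert plus token-choice routed experts."""

    def __init__(self, hidden_size=8, num_experts=4, top_k=2):
        super().__init__()
        self.shared_expert = nn.Linear(hidden_size, hidden_size)
        self.experts = nn.ModuleList([nn.Linear(hidden_size, hidden_size) for _ in range(num_experts)])
        self.router = nn.Linear(hidden_size, num_experts)
        self.top_k = top_k

    def forward(self, x):  # x: (num_tokens, hidden_size)
        shared_out = self.shared_expert(x)            # always active for every token
        weights, ids = self.router(x).sigmoid().topk(self.top_k, dim=-1)
        weights = weights / weights.sum(dim=-1, keepdim=True)
        routed_out = torch.zeros_like(x)
        for e, expert in enumerate(self.experts):     # naive dispatch, fine for a toy example
            expert_weight = (weights * (ids == e)).sum(dim=-1, keepdim=True)
            routed_out += expert_weight * expert(x)
        return shared_out + routed_out                # stable shared baseline + sparse routed capacity

output = ToyMoEWithSharedExpert()(torch.randn(3, 8))
```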

### Attention Mechanism

The hybrid attention pattern alternates between:

- **Sliding Window Attention**: For efficiency on long sequences, with configurable window size
- **Full Attention**: Applied every N layers (configurable via `global_attn_every_n_layers`) for global context

All attention layers include Q/K normalization and output gating for improved training dynamics.
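
The layer pattern can be inspected directly from the configuration, which derives `layer_types` from `global_attn_every_n_layers` unless an explicit list is passed. For instance, with an 8-layer configuration:

```py
from transformers import AfmoeConfig

config = AfmoeConfig(num_hidden_layers=8, global_attn_every_n_layers=4)
print(config.layer_types)
# ['sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention',
#  'sliding_attention', 'sliding_attention', 'sliding_attention', 'full_attention']
```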

## AfmoeConfig

[[autodoc]] AfmoeConfig

## AfmoeModel

[[autodoc]] AfmoeModel
- forward

## AfmoeForCausalLM

[[autodoc]] AfmoeForCausalLM
- forward
1 change: 1 addition & 0 deletions splitted_tests.txt
@@ -0,0 +1 @@
tests/models/afmoe/test_modeling_afmoe.py
27 changes: 27 additions & 0 deletions src/transformers/models/afmoe/__init__.py
@@ -0,0 +1,27 @@
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_afmoe import *
from .modeling_afmoe import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
225 changes: 225 additions & 0 deletions src/transformers/models/afmoe/configuration_afmoe.py
@@ -0,0 +1,225 @@
# coding=utf-8
# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""AFMoE model configuration"""

from typing import Optional

from ...configuration_utils import PreTrainedConfig, layer_type_validation
from ...modeling_rope_utils import rope_config_validation, standardize_rope_params
from ...utils import logging


logger = logging.get_logger(__name__)


class AfmoeConfig(PreTrainedConfig):
r"""
This is the configuration class to store the configuration of a [`AfmoeModel`]. It is used to instantiate an
AFMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration
with the defaults will yield a similar configuration to that of [arcee-ai/Trinity-Mini](https://huggingface.co/arcee-ai/Trinity-Mini).

AFMoE (Arcee Foundational Mixture of Experts) is a sparse Mixture of Experts model with token-choice routing, shared
experts, and a hybrid attention mechanism combining sliding window and full attention patterns.

Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PreTrainedConfig`] for more information.

Args:
vocab_size (`int`, *optional*, defaults to 200192):
Vocabulary size of the AFMoE model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`AfmoeModel`].
hidden_size (`int`, *optional*, defaults to 2048):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 6144):
Dimension of the dense MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 1408):
Intermediate size of the routed expert MLPs.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_dense_layers (`int`, *optional*, defaults to 1):
Number of initial dense layers before MoE layers begin. Layers with index < num_dense_layers will use
standard dense MLPs instead of MoE.
num_attention_heads (`int`, *optional*, defaults to 16):
Number of attention heads for each attention layer in the Transformer decoder.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details, check out [this
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
`num_attention_heads`.
head_dim (`int`, *optional*, defaults to 128):
The dimension of each attention head.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the MLP blocks.
max_position_embeddings (`int`, *optional*, defaults to 16384):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the RMS normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether the model's input and output word embeddings should be tied.
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`dict`, *optional*):
Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain
a value for `rope_type` and optionally parameters used for scaling in case you want to use RoPE
with longer `max_position_embeddings`.
num_experts (`int`, *optional*, defaults to 64):
Number of routed experts in MoE layers.
num_experts_per_tok (`int`, *optional*, defaults to 6):
Number of experts to route each token to. This is the top-k value for the token-choice routing.
num_shared_experts (`int`, *optional*, defaults to 2):
Number of shared experts that are always activated for all tokens.
score_func (`str`, *optional*, defaults to `"sigmoid"`):
The scoring function for routing decisions. Can be either "sigmoid" or "softmax".
route_norm (`bool`, *optional*, defaults to `True`):
Whether to normalize routing weights when using sigmoid scoring.
route_scale (`float`, *optional*, defaults to 1.0):
Scaling factor applied to routing weights.
global_attn_every_n_layers (`int`, *optional*, defaults to 4):
The frequency of full attention layers. Every Nth layer will use full attention, while others use sliding
window attention.
sliding_window (`int`, *optional*, defaults to 1024):
Sliding window size for local attention layers.
mup_enabled (`bool`, *optional*, defaults to `False`):
Whether to enable muP (Maximal Update Parametrization) scaling for training stability.
layer_types (`list[str]`, *optional*):
A list that explicitly maps each layer index with its attention type. Each element should be either
"sliding_attention" or "full_attention". If not provided, it will be automatically generated based on
`global_attn_every_n_layers`.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.

Example:
```python
>>> from transformers import AfmoeModel, AfmoeConfig

>>> # Initializing an AFMoE configuration
>>> configuration = AfmoeConfig()

>>> # Initializing a model from the arcee-ai/Trinity-Mini style configuration
>>> model = AfmoeModel(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
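
>>> # Illustrative only: override a few MoE and attention settings (example values, not tuned recommendations)
>>> custom_configuration = AfmoeConfig(num_experts=32, num_experts_per_tok=4, global_attn_every_n_layers=4)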
```
"""

model_type = "afmoe"
keys_to_ignore_at_inference = ["past_key_values"]

# Default pipeline parallel plan for base model
base_model_pp_plan = {
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
"norm": (["hidden_states"], ["hidden_states"]),
}

def __init__(
self,
vocab_size: Optional[int] = 200192,
hidden_size: Optional[int] = 2048,
intermediate_size: Optional[int] = 6144,
moe_intermediate_size: Optional[int] = 1408,
num_hidden_layers: Optional[int] = 32,
num_dense_layers: Optional[int] = 1,
num_attention_heads: Optional[int] = 16,
num_key_value_heads: Optional[int] = None,
head_dim: Optional[int] = 128,
hidden_act: Optional[str] = "silu",
max_position_embeddings: Optional[int] = 16384,
initializer_range: Optional[float] = 0.02,
rms_norm_eps: Optional[float] = 1e-5,
use_cache: Optional[bool] = True,
tie_word_embeddings: Optional[bool] = False,
rope_theta: Optional[float] = 10000.0,
rope_scaling: Optional[dict] = None,
num_experts: Optional[int] = 64,
num_experts_per_tok: Optional[int] = 6,
num_shared_experts: Optional[int] = 2,
score_func: Optional[str] = "sigmoid",
route_norm: Optional[bool] = True,
route_scale: Optional[float] = 1.0,
global_attn_every_n_layers: Optional[int] = 4,
sliding_window: Optional[int] = 1024,
mup_enabled: Optional[bool] = False,
layer_types: Optional[list] = None,
attention_dropout: Optional[float] = 0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_dense_layers = num_dense_layers
self.num_attention_heads = num_attention_heads
self.head_dim = head_dim
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling

# MoE specific
self.moe_intermediate_size = moe_intermediate_size
self.num_experts_per_tok = num_experts_per_tok
self.num_experts = num_experts
self.num_shared_experts = num_shared_experts
self.score_func = score_func
self.route_norm = route_norm
self.route_scale = route_scale

# Attention specific
self.attention_dropout = attention_dropout
self.global_attn_every_n_layers = global_attn_every_n_layers
self.sliding_window = sliding_window
self.layer_types = layer_types
if self.layer_types is None:
self.layer_types = [
"sliding_attention" if bool((i + 1) % global_attn_every_n_layers) else "full_attention"
for i in range(self.num_hidden_layers)
]
layer_type_validation(self.layer_types)

# muP specific
self.mup_enabled = mup_enabled

if num_key_value_heads is None:
num_key_value_heads = num_attention_heads

self.num_key_value_heads = num_key_value_heads

# Setup and validate rope configs
self.rope_parameters = rope_scaling
standardize_rope_params(self, rope_theta=rope_theta)
if self.rope_scaling is not None and "type" in self.rope_scaling:
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
rope_config_validation(self)

super().__init__(
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)


__all__ = ["AfmoeConfig"]