|
| 1 | +# coding=utf-8 |
| 2 | +# Copyright 2025 Arcee AI and the HuggingFace Inc. team. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | +"""AFMoE model configuration""" |
| 16 | + |
| 17 | +from typing import Optional |
| 18 | + |
| 19 | +from ...configuration_utils import PreTrainedConfig, layer_type_validation |
| 20 | +from ...modeling_rope_utils import rope_config_validation, standardize_rope_params |
| 21 | +from ...utils import logging |
| 22 | + |
| 23 | + |
| 24 | +logger = logging.get_logger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +class AfmoeConfig(PreTrainedConfig): |
| 28 | + r""" |
| 29 | + This is the configuration class to store the configuration of a [`AfmoeModel`]. It is used to instantiate an |
| 30 | + AFMoE model according to the specified arguments, defining the model architecture. Instantiating a configuration |
| 31 | + with the defaults will yield a similar configuration to that of [arcee-ai/Trinity-Mini](https://huggingface.co/arcee-ai/Trinity-Mini). |
| 32 | +
|
| 33 | + AFMoE is an Adaptive Feedforward MoE (Mixture of Experts) model with token-choice routing, shared experts, and a |
| 34 | + hybrid attention mechanism combining sliding window and full attention patterns. |
| 35 | +
|
| 36 | + Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the |
| 37 | + documentation from [`PreTrainedConfig`] for more information. |
| 38 | +
|
| 39 | + Args: |
| 40 | + vocab_size (`int`, *optional*, defaults to 200192): |
| 41 | + Vocabulary size of the AFMoE model. Defines the number of different tokens that can be represented by the |
| 42 | + `inputs_ids` passed when calling [`AfmoeModel`]. |
| 43 | + hidden_size (`int`, *optional*, defaults to 2048): |
| 44 | + Dimension of the hidden representations. |
| 45 | + intermediate_size (`int`, *optional*, defaults to 6144): |
| 46 | + Dimension of the dense MLP representations. |
| 47 | + moe_intermediate_size (`int`, *optional*, defaults to 1408): |
| 48 | + Intermediate size of the routed expert MLPs. |
| 49 | + num_hidden_layers (`int`, *optional*, defaults to 32): |
| 50 | + Number of hidden layers in the Transformer decoder. |
| 51 | + num_dense_layers (`int`, *optional*, defaults to 1): |
| 52 | + Number of initial dense layers before MoE layers begin. Layers with index < num_dense_layers will use |
| 53 | + standard dense MLPs instead of MoE. |
| 54 | + num_attention_heads (`int`, *optional*, defaults to 16): |
| 55 | + Number of attention heads for each attention layer in the Transformer decoder. |
| 56 | + num_key_value_heads (`int`, *optional*): |
| 57 | + This is the number of key_value heads that should be used to implement Grouped Query Attention. If |
| 58 | + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if |
| 59 | + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When |
| 60 | + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed |
| 61 | + by meanpooling all the original heads within that group. For more details, check out [this |
| 62 | + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to |
| 63 | + `num_attention_heads`. |
| 64 | + head_dim (`int`, *optional*, defaults to 128): |
| 65 | + The dimension of each attention head. |
| 66 | + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): |
| 67 | + The non-linear activation function (function or string) in the MLP blocks. |
| 68 | + max_position_embeddings (`int`, *optional*, defaults to 16384): |
| 69 | + The maximum sequence length that this model might ever be used with. |
| 70 | + initializer_range (`float`, *optional*, defaults to 0.02): |
| 71 | + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. |
| 72 | + rms_norm_eps (`float`, *optional*, defaults to 1e-05): |
| 73 | + The epsilon used by the RMS normalization layers. |
| 74 | + use_cache (`bool`, *optional*, defaults to `True`): |
| 75 | + Whether or not the model should return the last key/values attentions (not used by all models). Only |
| 76 | + relevant if `config.is_decoder=True`. |
| 77 | + tie_word_embeddings (`bool`, *optional*, defaults to `False`): |
| 78 | + Whether the model's input and output word embeddings should be tied. |
| 79 | + rope_theta (`float`, *optional*, defaults to 10000.0): |
| 80 | + The base period of the RoPE embeddings. |
| 81 | + rope_scaling (`dict`, *optional*): |
| 82 | + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain |
| 83 | + a value for `rope_type` and optionally parameters used for scaling in case you want to use RoPE |
| 84 | + with longer `max_position_embeddings`. |
| 85 | + num_experts (`int`, *optional*, defaults to 64): |
| 86 | + Number of routed experts in MoE layers. |
| 87 | + num_experts_per_tok (`int`, *optional*, defaults to 6): |
| 88 | + Number of experts to route each token to. This is the top-k value for the token-choice routing. |
| 89 | + num_shared_experts (`int`, *optional*, defaults to 2): |
| 90 | + Number of shared experts that are always activated for all tokens. |
| 91 | + score_func (`str`, *optional*, defaults to `"sigmoid"`): |
| 92 | + The scoring function for routing decisions. Can be either "sigmoid" or "softmax". |
| 93 | + route_norm (`bool`, *optional*, defaults to `True`): |
| 94 | + Whether to normalize routing weights when using sigmoid scoring. |
| 95 | + route_scale (`float`, *optional*, defaults to 1.0): |
| 96 | + Scaling factor applied to routing weights. |
| 97 | + global_attn_every_n_layers (`int`, *optional*, defaults to 4): |
| 98 | + The frequency of full attention layers. Every Nth layer will use full attention, while others use sliding |
| 99 | + window attention. |
| 100 | + sliding_window (`int`, *optional*, defaults to 1024): |
| 101 | + Sliding window size for local attention layers. |
| 102 | + mup_enabled (`bool`, *optional*, defaults to `False`): |
| 103 | + Whether to enable muP (Maximal Update Parametrization) scaling for training stability. |
| 104 | + layer_types (`list[str]`, *optional*): |
| 105 | + A list that explicitly maps each layer index with its attention type. Each element should be either |
| 106 | + "sliding_attention" or "full_attention". If not provided, it will be automatically generated based on |
| 107 | + `global_attn_every_n_layers`. |
| 108 | + attention_dropout (`float`, *optional*, defaults to 0.0): |
| 109 | + The dropout ratio for the attention probabilities. |
| 110 | +
|
| 111 | + Example: |
| 112 | + ```python |
| 113 | + >>> from transformers import AfmoeModel, AfmoeConfig |
| 114 | +
|
| 115 | + >>> # Initializing an AFMoE configuration |
| 116 | + >>> configuration = AfmoeConfig() |
| 117 | +
|
| 118 | + >>> # Initializing a model from the afmoe-small-sft-v1 style configuration |
| 119 | + >>> model = AfmoeModel(configuration) |
| 120 | +
|
| 121 | + >>> # Accessing the model configuration |
| 122 | + >>> configuration = model.config |
| 123 | + ``` |
| 124 | + """ |
| 125 | + |
| 126 | + model_type = "afmoe" |
| 127 | + keys_to_ignore_at_inference = ["past_key_values"] |
| 128 | + |
| 129 | + # Default pipeline parallel plan for base model |
| 130 | + base_model_pp_plan = { |
| 131 | + "embed_tokens": (["input_ids"], ["inputs_embeds"]), |
| 132 | + "layers": (["hidden_states", "attention_mask"], ["hidden_states"]), |
| 133 | + "norm": (["hidden_states"], ["hidden_states"]), |
| 134 | + } |
| 135 | + |
| 136 | + def __init__( |
| 137 | + self, |
| 138 | + vocab_size: Optional[int] = 200192, |
| 139 | + hidden_size: Optional[int] = 2048, |
| 140 | + intermediate_size: Optional[int] = 6144, |
| 141 | + moe_intermediate_size: Optional[int] = 1408, |
| 142 | + num_hidden_layers: Optional[int] = 32, |
| 143 | + num_dense_layers: Optional[int] = 1, |
| 144 | + num_attention_heads: Optional[int] = 16, |
| 145 | + num_key_value_heads: Optional[int] = None, |
| 146 | + head_dim: Optional[int] = 128, |
| 147 | + hidden_act: Optional[str] = "silu", |
| 148 | + max_position_embeddings: Optional[int] = 16384, |
| 149 | + initializer_range: Optional[float] = 0.02, |
| 150 | + rms_norm_eps: Optional[float] = 1e-5, |
| 151 | + use_cache: Optional[bool] = True, |
| 152 | + tie_word_embeddings: Optional[bool] = False, |
| 153 | + rope_theta: Optional[float] = 10000.0, |
| 154 | + rope_scaling: Optional[dict] = None, |
| 155 | + num_experts: Optional[int] = 64, |
| 156 | + num_experts_per_tok: Optional[int] = 6, |
| 157 | + num_shared_experts: Optional[int] = 2, |
| 158 | + score_func: Optional[str] = "sigmoid", |
| 159 | + route_norm: Optional[bool] = True, |
| 160 | + route_scale: Optional[float] = 1.0, |
| 161 | + global_attn_every_n_layers: Optional[int] = 4, |
| 162 | + sliding_window: Optional[int] = 1024, |
| 163 | + mup_enabled: Optional[bool] = False, |
| 164 | + layer_types: Optional[list] = None, |
| 165 | + attention_dropout: Optional[float] = 0.0, |
| 166 | + **kwargs, |
| 167 | + ): |
| 168 | + self.vocab_size = vocab_size |
| 169 | + self.max_position_embeddings = max_position_embeddings |
| 170 | + self.hidden_size = hidden_size |
| 171 | + self.intermediate_size = intermediate_size |
| 172 | + self.num_hidden_layers = num_hidden_layers |
| 173 | + self.num_dense_layers = num_dense_layers |
| 174 | + self.num_attention_heads = num_attention_heads |
| 175 | + self.head_dim = head_dim |
| 176 | + self.hidden_act = hidden_act |
| 177 | + self.initializer_range = initializer_range |
| 178 | + self.rms_norm_eps = rms_norm_eps |
| 179 | + self.use_cache = use_cache |
| 180 | + self.rope_theta = rope_theta |
| 181 | + self.rope_scaling = rope_scaling |
| 182 | + |
| 183 | + # MoE specific |
| 184 | + self.moe_intermediate_size = moe_intermediate_size |
| 185 | + self.num_experts_per_tok = num_experts_per_tok |
| 186 | + self.num_experts = num_experts |
| 187 | + self.num_shared_experts = num_shared_experts |
| 188 | + self.score_func = score_func |
| 189 | + self.route_norm = route_norm |
| 190 | + self.route_scale = route_scale |
| 191 | + |
| 192 | + # Attention specific |
| 193 | + self.attention_dropout = attention_dropout |
| 194 | + self.global_attn_every_n_layers = global_attn_every_n_layers |
| 195 | + self.sliding_window = sliding_window |
| 196 | + self.layer_types = layer_types |
| 197 | + if self.layer_types is None: |
| 198 | + self.layer_types = [ |
| 199 | + "sliding_attention" |
| 200 | + if bool((i + 1) % global_attn_every_n_layers) |
| 201 | + else "full_attention" |
| 202 | + for i in range(self.num_hidden_layers) |
| 203 | + ] |
| 204 | + layer_type_validation(self.layer_types) |
| 205 | + |
| 206 | + # muP specific |
| 207 | + self.mup_enabled = mup_enabled |
| 208 | + |
| 209 | + if num_key_value_heads is None: |
| 210 | + num_key_value_heads = num_attention_heads |
| 211 | + |
| 212 | + self.num_key_value_heads = num_key_value_heads |
| 213 | + |
| 214 | + # Setup and validate rope configs |
| 215 | + self.rope_parameters = rope_scaling |
| 216 | + standardize_rope_params(self, rope_theta=rope_theta) |
| 217 | + if self.rope_scaling is not None and "type" in self.rope_scaling: |
| 218 | + self.rope_scaling["rope_type"] = self.rope_scaling["type"] |
| 219 | + rope_config_validation(self) |
| 220 | + |
| 221 | + super().__init__( |
| 222 | + tie_word_embeddings=tie_word_embeddings, |
| 223 | + **kwargs, |
| 224 | + ) |
| 225 | + |
| 226 | + |
| 227 | +__all__ = ["AfmoeConfig"] |
| 228 | + |
0 commit comments