Conversation
```python
def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids: Optional[torch.Tensor] = None, unsqueeze_dim: int = 1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


def eager_attention_forward(
    module: nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    scaling: float,
    dropout: float = 0.0,
    **kwargs,
):
    key_states = repeat_kv(key, module.num_key_value_groups)
    value_states = repeat_kv(value, module.num_key_value_groups)

    attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
    if attention_mask is not None:
        causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
        attn_weights = attn_weights + causal_mask

    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value_states)
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, attn_weights
```
these can also be imported from Llama! 😉
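For reference, a minimal sketch of what reusing these helpers from Llama could look like; the import path is how these names are exposed in recent transformers versions, and inside the repo's `modular_afmoe.py` it would normally be a relative import rather than the absolute one shown here:

```python
# Hedged sketch: these helpers already exist in Llama's modeling file and could
# be imported instead of redefined. Inside the repo's modular file this would be
# a relative import (from ..llama.modeling_llama import ...).
from transformers.models.llama.modeling_llama import (
    apply_rotary_pos_emb,
    eager_attention_forward,
    repeat_kv,
    rotate_half,
)
```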
```python
top_scores, selected_experts = self.router(hidden_states, self.expert_bias)

# Process through shared experts
if self.shared_experts is not None:
```
same comment, is this used by the released model or not?
the first layer is a standard dense FFN, and all subsequent layers use the MoE block
In that case the arch should be different! Use a normal MLP for the mlp and experts for the experts! 🤗
You can set the first layer then just do `+=`; we want to avoid code paths as much as possible.
```python
# MoE or dense FFN
self.moe_enabled = layer_idx >= config.num_dense_layers
if self.moe_enabled:
    self.mlp = AfmoeMoE(config)
else:
    self.mlp = AfmoeMLP(config)
```
is moe disabled on any of the released ckpts? 🤗
```python
    This mirrors the Experts pattern used across other MoE models to ease checkpoint conversion.
    """

    _checkpoint_conversion_mapping = {"experts": "experts"}
```
Suggested change:

```python
_checkpoint_conversion_mapping = {"experts": "experts"}
```
```python
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)

if self.is_local_attention:
```
i did not get an answer
ArthurZucker left a comment:
Thanks, we tend to try and remove code paths as much as possible; if not done here we'll do it post release!
```python
_, selected_experts = torch.topk(scores + expert_bias, k=self.top_k, dim=1)
top_scores = scores.gather(dim=1, index=selected_experts)

if self.route_norm:
```
is this always True or False? (cf removing code path :)
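For context, a plausible reading of the `route_norm` branch, continuing the excerpt above; the renormalization and the epsilon guard are assumptions, not confirmed from the PR's code:

```python
# Hedged sketch of what route_norm likely does: renormalize the top-k scores
# per token so they sum to 1 before any scaling. The epsilon is an assumption.
if self.route_norm:
    denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20
    top_scores = top_scores / denominator
```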
```python
        return top_scores, selected_experts


class AfmoeExperts(nn.ModuleList):
```
you could just inherit from Mixtral or Qwen2Moe, it should be the same no?
The checkpoint weight structure is different in AFMoE
We have an online weight converter, but no worries :)
```python
if isinstance(module, nn.Linear):
    module.weight.normal_(mean=0.0, std=self.config.initializer_range)
    if module.bias is not None:
        module.bias.zero_()
elif isinstance(module, nn.Embedding):
    module.weight.normal_(mean=0.0, std=self.config.initializer_range)
    if module.padding_idx is not None:
        module.weight[module.padding_idx].zero_()
elif isinstance(module, AfmoeRMSNorm):
    module.weight.fill_(1.0)
```
these should not be used, can you use nn.init instead please! One of the CI jobs will fail as we require this for inits!
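A hedged sketch of the same initialization expressed through `nn.init`, as requested; the exact structure of `_init_weights` in the final PR may differ, and `AfmoeRMSNorm` refers to the PR's own class:

```python
import torch
import torch.nn as nn

def _init_weights(self, module):
    # Same logic as the excerpt above, but routed through nn.init.
    std = self.config.initializer_range
    if isinstance(module, nn.Linear):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Embedding):
        nn.init.normal_(module.weight, mean=0.0, std=std)
        if module.padding_idx is not None:
            # Mirrors nn.Embedding.reset_parameters for the padding row.
            with torch.no_grad():
                module.weight[module.padding_idx].fill_(0)
    elif isinstance(module, AfmoeRMSNorm):  # AfmoeRMSNorm comes from the PR's modeling file
        nn.init.ones_(module.weight)
```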
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
[For maintainers] Suggested jobs to run (before merge): run-slow: afmoe, auto
It seems like this model wasn't added to
Seems like it - how did the CI pass?
* Add AFMoE model support
* Address review feedback for AFMoE implementation
* Add flex attention support to AFMoE model
* Fix expert_bias routing in AFMoE
* Remove test-results directory
* Address PR review feedback for AFMoE model
* fix(afmoe): ensure RMSNorm output dtype matches input dtype
* properly return attn weights
* fix most tests
* cleanup: remove shared expert if/else as it defaults to 2; remove `route_norm` as it defaults to `True`; make tests smaller and faster
* fix input embeds api
* update rope API, smaller test and should be good to go
* oups, wrong place to skip unittest
* quality
* update
* rope parameter docstring fill

Co-authored-by: Arthur <[email protected]>
Co-authored-by: Arthur <[email protected]>
Summary
This PR adds support for the AFMoE (Arcee Foundational Mixture of Experts) model architecture for the upcoming Trinity-Mini and Trinity-Nano releases. AFMoE is a decoder-only transformer model featuring a sparse Mixture of Experts (MoE) approach, combining token-choice routing with shared experts and several architectural innovations for efficient inference and improved performance.
Model Description
AFMoE features the following key architectural components:

- Mixture of Experts with Shared Experts: Combines routed experts (activated per-token via learned routing) with always-active shared experts for stable base computation
- Token-Choice Routing: Uses sigmoid- or softmax-based routing with normalization and scaling for expert selection (see the sketch after this list)
- Q/K Normalization and Gating: Applies RMSNorm to query and key projections and uses sigmoid gating on attention outputs for improved training stability
- Hybrid Attention Patterns: Alternates between sliding window attention and full attention across layers for efficiency with long contexts
- Dual Normalization: Uses pre- and post-normalization around both attention and MLP blocks for training stability
- Configurable Dense Layers: Allows initial layers to use dense MLPs before transitioning to sparse MoE layers (`num_dense_layers`)
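To make the routing description concrete, here is a minimal sketch of token-choice routing combined with an always-active shared expert. All names (`router`, `experts`, `shared_expert`, `route_scale`), the shapes, and the sigmoid/top-k details are illustrative assumptions, not the PR's exact implementation:

```python
import torch

# Minimal sketch of token-choice routing plus an always-active shared expert.
def moe_forward(hidden_states, router, experts, shared_expert, top_k=2, route_scale=1.0):
    batch, seq_len, hidden = hidden_states.shape
    flat = hidden_states.reshape(-1, hidden)                        # (tokens, hidden)

    scores = torch.sigmoid(router(flat))                            # token-choice scores per expert
    top_scores, selected = torch.topk(scores, k=top_k, dim=-1)      # pick top-k experts per token
    top_scores = top_scores / top_scores.sum(dim=-1, keepdim=True)  # normalize, then scale
    top_scores = top_scores * route_scale

    out = torch.zeros_like(flat)
    for expert_idx, expert in enumerate(experts):
        token_idx, slot = torch.where(selected == expert_idx)       # tokens routed to this expert
        if token_idx.numel():
            out[token_idx] += top_scores[token_idx, slot, None] * expert(flat[token_idx])

    out = out + shared_expert(flat)                                 # shared experts always run
    return out.reshape(batch, seq_len, hidden)
```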
Implementation Details
Modular implementation leveraging transformers' modular architecture:

- Efficient `AfmoeRMSNorm` for layer normalization
- `AfmoeRotaryEmbedding` for positional encoding
- `AfmoeAttention` class implementing Q/K normalization and output gating
- `AfmoeTokenChoiceRouter` for expert selection
- `AfmoeMoE` class implementing the shared + routed experts architecture
- `AfmoeDecoderLayer` integrating attention and MoE blocks with dual normalization (sketched below)
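A minimal sketch of the dual-normalization residual pattern; the norm attribute names are assumptions, and only the ordering of pre/post norms around attention and the MLP/MoE block is the point:

```python
# Hedged sketch: pre- and post-norms wrap both attention and the MLP/MoE block.
def decoder_layer_forward(layer, hidden_states, **attn_kwargs):
    residual = hidden_states
    hidden_states = layer.input_layernorm(hidden_states)            # pre-attention norm
    hidden_states, _ = layer.self_attn(hidden_states, **attn_kwargs)
    hidden_states = layer.post_attention_layernorm(hidden_states)   # post-attention norm
    hidden_states = residual + hidden_states

    residual = hidden_states
    hidden_states = layer.pre_mlp_layernorm(hidden_states)          # pre-MLP norm
    hidden_states = layer.mlp(hidden_states)                        # AfmoeMLP or AfmoeMoE
    hidden_states = layer.post_mlp_layernorm(hidden_states)         # post-MLP norm
    hidden_states = residual + hidden_states
    return hidden_states
```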
Testing

- `arcee-ai/Trinity-Mini`

Documentation

- `docs/source/en/model_doc/afmoe.md`