Skip to content

[V1] [ROCm] Enable EP with AITER Fused MoE #20270

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 1, 2025

Conversation

tjtanaa
Copy link
Contributor

@tjtanaa tjtanaa commented Jun 30, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

Enable AITER Fused MoE expert parallelism feature.

Test Plan

Perform lm_eval on gsm8k_cot dataset.

Example test command

HF_HUB_OFFLINE=1 \
VLLM_USE_V1=1 \
VLLM_ROCM_USE_AITER=1 \
VLLM_ROCM_USE_AITER_LINEAR=0 \
VLLM_USE_TRITON_FLASH_ATTN=0 \
VLLM_ROCM_USE_AITER_RMSNORM=0 \
VLLM_ROCM_USE_AITER_MHA=0 \
VLLM_ROCM_USE_AITER_MOE=1 \
VLLM_ROCM_USE_AITER_PAGED_ATTN=0 \
VLLM_V1_USE_PREFILL_DECODE_ATTENTION=1 \
SAFETENSORS_FAST_GPU=1 \
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,max_model_len=10000 --trust_remote_code --tasks gsm8k_cot --num_fewshot 8 --batch_size 250 --seed 1234

Test Result

  1. deepseek-ai/DeepSeek-R1 V0 Engine TP8 EP8

vllm (pretrained=deepseek-ai/DeepSeek-R1,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,block_size=1,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.9303 ± 0.0070
strict-match 8 exact_match 0.9280 ± 0.0071
  1. deepseek-ai/DeepSeek-R1 V1 Engine TP8 EP8

vllm (pretrained=deepseek-ai/DeepSeek-R1,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,block_size=1,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.9386 ± 0.0066
strict-match 8 exact_match 0.9272 ± 0.0072
  1. Qwen/Qwen3-235B-A22B-FP8 V0 Engine TP8 EP8

vllm (pretrained=Qwen/Qwen3-235B-A22B-FP8,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.8992 ± 0.0083
strict-match 8 exact_match 0.8332 ± 0.0103
  1. Qwen/Qwen3-235B-A22B-FP8 V1 Engine TP8 EP8

vllm (pretrained=Qwen/Qwen3-235B-A22B-FP8,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.9098 ± 0.0079
strict-match 8 exact_match 0.8302 ± 0.0103
  1. RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic V1 Engine TP8 EP8 (Without AITER)

vllm (pretrained=RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.7369 ± 0.0121
strict-match 8 exact_match 0.9325 ± 0.0069
  1. RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic V1 TP8 EP8 (AITER tkw1)

vllm (pretrained=RedHatAI/Llama-4-Scout-17B-16E-Instruct-FP8-dynamic,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.7854 ± 0.0113
strict-match 8 exact_match 0.9333 ± 0.0069
  1. meta-llama/Llama-4-Scout-17B-16E-Instruct V1 Engine TP8 EP8 (Without AITER)

vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.7741 ± 0.0115
strict-match 8 exact_match 0.9325 ± 0.0069
  1. meta-llama/Llama-4-Scout-17B-16E-Instruct V1 Engine TP8 EP8 (AITER Fused MoE)

vllm (pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,tensor_parallel_size=8,enable_expert_parallel=True,add_bos_token=True,max_model_len=10000,trust_remote_code=True), gen_kwargs: (None), limit: None, num_fewshot: 8, batch_size: 250

Tasks Version Filter n-shot Metric Value Stderr
gsm8k_cot 3 flexible-extract 8 exact_match 0.7597 ± 0.0118
strict-match 8 exact_match 0.9295 ± 0.0071

Performance Gain

Dataset: Random
ISL:OSL= 1000:1000
Number of prompts: 500
Concurrency: 64
vLLM Engine Version: V1 Engine
Tensor Parallelism: 8
Expert Parallelism: 8

Comparison Scenarios: AITER Fused MoE only vs No AITER

  • Notes on DeepSeek-R1 benchmark: in V1 Engine, There is only AITER MLA backend, so the DeepSeek-R1 benchmark is AITER MLA + AITER Fused MoE (with aiter.biased_grouped_topk kernel) vs AITER MLA

  • Notes on Llama4 models benchmark. The flags for the other models:
    VLLM_ROCM_USE_AITER_LINEAR=0 VLLM_ROCM_USE_AITER_RMSNORM=0 VLLM_ROCM_USE_AITER_MHA=0

DeepSeek-R1

Metric No AITER AITER Gain
Request throughput (req/s) 1.56 2.90 +85.9%
Output token throughput (tok/s) 486.26 891.14 +83.3%
Total token throughput (tok/s) 2040.52 3784.11 +85.4%
Mean TTFT (ms) 2432.53 1572.64 -35.3% (better)
Mean TPOT (ms) 542.99 352.58 -35.1% (better)
Mean ITL (ms) 105.39 56.84 -46.1% (better)

Llama-4-Scout-17B-16E-Instruct

Metric No AITER AITER Gain
Request throughput (req/s) 3.57 5.56 +55.7%
Output token throughput (tok/s) 1019.63 1605.34 +57.4%
Total token throughput (tok/s) 4582.81 7153.24 +56.1%
Mean TTFT (ms) 536.00 323.43 -39.7% (better)
Mean TPOT (ms) 51.91 33.55 -35.4% (better)
Mean ITL (ms) 50.50 32.56 -35.5% (better)

Llama-4-Maverick-17B-128E-Instruct-FP8

Metric No AITER AITER Gain
Request throughput (req/s) 3.54 5.62 +58.8%
Output token throughput (tok/s) 1026.57 1587.21 +54.6%
Total token throughput (tok/s) 4563.25 7198.01 +57.7%
Mean TTFT (ms) 539.02 331.61 -38.5% (better)
Mean TPOT (ms) 51.32 34.13 -33.5% (better)
Mean ITL (ms) 50.05 32.85 -34.4% (better)

(Optional) Documentation Update

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the rocm Related to AMD ROCm label Jun 30, 2025
Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @tjtanaa, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enables the AITER Fused MoE expert parallelism feature for ROCm by introducing and utilizing an expert mask. The changes involve passing the expert map through the layers and creating a mask based on its values to indicate valid experts during computation. This enhancement is crucial for models employing expert parallelism, particularly in scenarios where certain experts might be invalid or masked out.

Highlights

  • Expert Masking: Introduced an expert_mask parameter to the rocm_aiter_fused_experts function in vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py. This mask is used to indicate valid experts.
  • Expert Map Propagation: Propagated the expert_map parameter from the higher-level apply functions in vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py and vllm/model_executor/layers/quantization/fp8.py to the rocm_aiter_fused_experts function.
  • Conditional Expert Mask Creation: Added logic to create the expert_mask based on the expert_map in vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py. If expert_map is provided, the mask is created; otherwise, it remains None.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist is currently in preview and may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments to provide feedback.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

The pull request successfully enables the AITER Fused MoE expert parallelism feature for ROCm. The changes correctly introduce and propagate the expert_map parameter, which is then used to derive an expert_mask for the underlying kernels. The modifications are well-contained and appear to be functionally correct and consistent across the affected files. No critical, high, or medium severity issues were identified in the provided diffs.

@tjtanaa tjtanaa marked this pull request as ready for review June 30, 2025 17:27
Copy link
Member

@mgoin mgoin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a huge perf improvement, nice! Thanks for the detailed comparisons. BTW do we have a kernel test for aiter fused moe that we could add a case for EP to?

@tjtanaa
Copy link
Contributor Author

tjtanaa commented Jul 1, 2025

@mgoin Since ROCm/aiter repo is a repo for kernels, maybe it is better to keep the kernel level unit test within ROCm/aiter repo?

@DarkLight1337 DarkLight1337 enabled auto-merge (squash) July 1, 2025 13:03
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 1, 2025
@DarkLight1337 DarkLight1337 merged commit 02cabff into vllm-project:main Jul 1, 2025
96 checks passed
CSWYF3634076 pushed a commit to CSWYF3634076/vllm that referenced this pull request Jul 2, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed rocm Related to AMD ROCm
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants