@jinzhen-lin (Contributor) commented Nov 28, 2025

Introduction

This PR optimizes moe_align_block_size in three ways:

  1. For small batch sizes, use the smallest possible value of max_num_tokens_padded instead of the worst-case bound.

  2. In the CUDA kernel, use otherwise-idle thread and threadblock resources to fill sorted_token_ids. The previous kernel used very few compute resources and filled sorted_token_ids sequentially after counting tokens per expert. Since sorted_token_ids is only read again at the very end of the kernel, the fill now runs in parallel with the counting, which shortens kernel execution.

  3. For expert parallelism (EP), invalid experts are filtered out directly while counting tokens per expert. This speeds up moe_align_block_size and also yields a cleaner, faster implementation in the downstream MoE kernel.
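The overall contract of moe_align_block_size, the small-batch bound from (1), and the early EP filtering from (3) can be modeled with a pure-Python sketch. The function names mirror the op, but the list-based layout, the expert_map convention, and the exact min() bound are my illustration of the idea, not necessarily the code in this PR:

```python
import math

def max_num_tokens_padded(num_tokens, topk, num_experts, block_size):
    # Worst case: every expert needs up to (block_size - 1) padding slots.
    worst_case = num_tokens * topk + num_experts * (block_size - 1)
    # A small batch occupies at most num_tokens * topk experts, and each
    # occupied expert expands to at most one block_size-sized block per
    # assignment, so this is a tighter bound when num_tokens is small.
    return min(worst_case, num_tokens * topk * block_size)

def moe_align_block_size_ref(topk_ids, num_local_experts, block_size,
                             expert_map=None):
    # topk_ids: per-token lists of expert ids. Groups each (token, expert)
    # assignment by expert and pads every group to a multiple of block_size.
    flat = [e for row in topk_ids for e in row]
    pad = len(flat)                          # sentinel id for padding slots
    buckets = [[] for _ in range(num_local_experts)]
    for idx, e in enumerate(flat):
        if expert_map is not None:
            e = expert_map[e]
            if e < 0:                        # EP: drop invalid experts early
                continue
        buckets[e].append(idx)
    sorted_token_ids, expert_ids = [], []
    for e, toks in enumerate(buckets):
        n_blocks = math.ceil(len(toks) / block_size)
        sorted_token_ids += toks + [pad] * (n_blocks * block_size - len(toks))
        expert_ids += [e] * n_blocks
    return sorted_token_ids, expert_ids, len(sorted_token_ids)

# Two tokens with topk=2, four experts, block_size=2, no EP:
print(moe_align_block_size_ref([[0, 1], [1, 2]], 4, 2))
```

In the CUDA kernel the two loops above correspond to the per-expert counting and the sorted_token_ids fill; this PR overlaps the fill with the counting instead of running them back to back.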

Kernel Bench

On an RTX 4090:

moe-align-block-size-performance:
     num_tokens  num_experts  topk  ep_size        main         PR
0           1.0         16.0   1.0      1.0    8.192000   6.144000
1           1.0         16.0   1.0      8.0   12.288000   6.144000
2           1.0         16.0   2.0      1.0    7.168000   6.080000
3           1.0         16.0   2.0      8.0   11.264000   5.120000
4           1.0         16.0   8.0      1.0    7.232000   5.120000
5           1.0         16.0   8.0      8.0   12.224000   6.144000
6           1.0         64.0   1.0      1.0   13.312000   9.216000
7           1.0         64.0   1.0      8.0   17.408000   9.216000
8           1.0         64.0   2.0      1.0   13.312000   9.216000
9           1.0         64.0   2.0      8.0   17.408000   9.216000
10          1.0         64.0   8.0      1.0   13.312000   9.216000
11          1.0         64.0   8.0      8.0   17.408000   9.216000
12          1.0        224.0   1.0      1.0   10.240000   7.168000
13          1.0        224.0   1.0      8.0   14.336000   7.168000
14          1.0        224.0   2.0      1.0   10.240000   7.168000
15          1.0        224.0   2.0      8.0   14.336000   7.168000
16          1.0        224.0   8.0      1.0   10.240000   7.168000
17          1.0        224.0   8.0      8.0   14.336000   7.168000
18          1.0        256.0   1.0      1.0   10.240000   7.168000
19          1.0        256.0   1.0      8.0   14.336000   8.192000
20          1.0        256.0   2.0      1.0   10.240000   7.168000
21          1.0        256.0   2.0      8.0   14.336000   8.192000
22          1.0        256.0   8.0      1.0   10.240000   7.168000
23          1.0        256.0   8.0      8.0   14.336000   8.192000
24          1.0        280.0   1.0      1.0   11.264000   7.168000
25          1.0        280.0   1.0      8.0   15.360000   7.168000
26          1.0        280.0   2.0      1.0   11.264000   7.168000
27          1.0        280.0   2.0      8.0   15.360000   7.168000
28          1.0        280.0   8.0      1.0   11.264000   7.168000
29          1.0        280.0   8.0      8.0   15.360000   7.168000
30          1.0        512.0   1.0      1.0   13.312000   7.168000
31          1.0        512.0   1.0      8.0   18.432001   7.168000
32          1.0        512.0   2.0      1.0   13.312000   7.168000
33          1.0        512.0   2.0      8.0   18.432001   7.168000
34          1.0        512.0   8.0      1.0   13.312000   7.168000
35          1.0        512.0   8.0      8.0   18.432001   7.168000
36         16.0         16.0   1.0      1.0    7.168000   6.080000
37         16.0         16.0   1.0      8.0   11.264000   6.144000
38         16.0         16.0   2.0      1.0    7.168000   5.120000
39         16.0         16.0   2.0      8.0   12.288000   5.120000
40         16.0         16.0   8.0      1.0    8.192000   6.144000
41         16.0         16.0   8.0      8.0   13.312000   7.168000
42         16.0         64.0   1.0      1.0   13.312000   9.216000
43         16.0         64.0   1.0      8.0   17.408000   9.216000
44         16.0         64.0   2.0      1.0   13.312000   9.216000
45         16.0         64.0   2.0      8.0   17.408000   9.216000
46         16.0         64.0   8.0      1.0   13.312000   9.312000
47         16.0         64.0   8.0      8.0   17.408000   9.216000
48         16.0        224.0   1.0      1.0   10.240000   7.168000
49         16.0        224.0   1.0      8.0   14.336000   8.192000
50         16.0        224.0   2.0      1.0   10.240000   7.168000
51         16.0        224.0   2.0      8.0   14.336000   8.192000
52         16.0        224.0   8.0      1.0   10.240000   7.168000
53         16.0        224.0   8.0      8.0   14.336000   8.192000
54         16.0        256.0   1.0      1.0   10.240000   7.168000
55         16.0        256.0   1.0      8.0   14.336000   8.192000
56         16.0        256.0   2.0      1.0   10.240000   7.168000
57         16.0        256.0   2.0      8.0   14.336000   8.192000
58         16.0        256.0   8.0      1.0   10.240000   7.168000
59         16.0        256.0   8.0      8.0   14.336000   8.192000
60         16.0        280.0   1.0      1.0   11.264000   7.168000
61         16.0        280.0   1.0      8.0   15.360000   8.192000
62         16.0        280.0   2.0      1.0   11.136000   7.168000
63         16.0        280.0   2.0      8.0   15.360000   8.192000
64         16.0        280.0   8.0      1.0   11.264000   7.168000
65         16.0        280.0   8.0      8.0   15.360000   8.192000
66         16.0        512.0   1.0      1.0   13.312000   7.168000
67         16.0        512.0   1.0      8.0   18.432001   8.192000
68         16.0        512.0   2.0      1.0   13.312000   7.168000
69         16.0        512.0   2.0      8.0   18.432001   8.192000
70         16.0        512.0   8.0      1.0   13.312000   8.128000
71         16.0        512.0   8.0      8.0   18.432001   8.192000
72        256.0         16.0   1.0      1.0    9.216000   7.168000
73        256.0         16.0   1.0      8.0   13.312000   8.192000
74        256.0         16.0   2.0      1.0   12.288000  10.240000
75        256.0         16.0   2.0      8.0   15.360000  11.264000
76        256.0         16.0   8.0      1.0    8.192000   8.192000
77        256.0         16.0   8.0      8.0   12.288000   8.192000
78        256.0         64.0   1.0      1.0   14.336000  10.240000
79        256.0         64.0   1.0      8.0   18.432001  10.736000
80        256.0         64.0   2.0      1.0   14.336000  11.264000
81        256.0         64.0   2.0      8.0   18.432001  12.288000
82        256.0         64.0   8.0      1.0    9.216000   8.192000
83        256.0         64.0   8.0      8.0   12.288000   8.192000
84        256.0        224.0   1.0      1.0   10.240000   8.192000
85        256.0        224.0   1.0      8.0   14.432000   8.192000
86        256.0        224.0   2.0      1.0   10.240000   8.192000
87        256.0        224.0   2.0      8.0   14.336000   8.192000
88        256.0        224.0   8.0      1.0   11.264000   8.192000
89        256.0        224.0   8.0      8.0   15.360000   8.192000
90        256.0        256.0   1.0      1.0   10.240000   8.192000
91        256.0        256.0   1.0      8.0   15.360000   8.192000
92        256.0        256.0   2.0      1.0   10.240000   8.192000
93        256.0        256.0   2.0      8.0   15.360000   8.192000
94        256.0        256.0   8.0      1.0   11.264000   9.216000
95        256.0        256.0   8.0      8.0   16.384000   8.192000
96        256.0        280.0   1.0      1.0   11.264000   8.192000
97        256.0        280.0   1.0      8.0   15.360000   8.192000
98        256.0        280.0   2.0      1.0   11.264000   9.216000
99        256.0        280.0   2.0      8.0   16.384000   9.216000
100       256.0        280.0   8.0      1.0   11.264000   9.216000
101       256.0        280.0   8.0      8.0   16.527999   9.216000
102       256.0        512.0   1.0      1.0   13.312000   8.192000
103       256.0        512.0   1.0      8.0   18.592000   8.192000
104       256.0        512.0   2.0      1.0   13.312000  11.264000
105       256.0        512.0   2.0      8.0   19.455999  11.264000
106       256.0        512.0   8.0      1.0   14.336000  11.264000
107       256.0        512.0   8.0      8.0   19.455999  11.264000
108      4096.0         16.0   1.0      1.0    9.216000   9.216000
109      4096.0         16.0   1.0      8.0   13.312000   9.216000
110      4096.0         16.0   2.0      1.0   11.264000  11.264000
111      4096.0         16.0   2.0      8.0   15.360000  10.240000
112      4096.0         16.0   8.0      1.0   22.528000  20.608000
113      4096.0         16.0   8.0      8.0   26.624000  19.455999
114      4096.0         64.0   1.0      1.0   10.240000   9.216000
115      4096.0         64.0   1.0      8.0   14.336000   9.216000
116      4096.0         64.0   2.0      1.0   12.288000  11.264000
117      4096.0         64.0   2.0      8.0   16.384000  10.240000
118      4096.0         64.0   8.0      1.0   25.599999  23.552001
119      4096.0         64.0   8.0      8.0   30.719999  18.432001
120      4096.0        224.0   1.0      1.0   12.288000   9.216000
121      4096.0        224.0   1.0      8.0   16.384000   9.216000
122      4096.0        224.0   2.0      1.0   14.336000  11.264000
123      4096.0        224.0   2.0      8.0   18.432001  10.240000
124      4096.0        224.0   8.0      1.0   24.576001  20.479999
125      4096.0        224.0   8.0      8.0   28.672000  19.455999
126      4096.0        256.0   1.0      1.0   12.288000   9.216000
127      4096.0        256.0   1.0      8.0   17.408000   9.216000
128      4096.0        256.0   2.0      1.0   14.336000  11.264000
129      4096.0        256.0   2.0      8.0   19.455999  10.240000
130      4096.0        256.0   8.0      1.0   27.648000  23.552001
131      4096.0        256.0   8.0      8.0   32.768000  19.455999
132      4096.0        280.0   1.0      1.0   12.288000   9.216000
133      4096.0        280.0   1.0      8.0   17.408000   9.216000
134      4096.0        280.0   2.0      1.0   15.360000  11.264000
135      4096.0        280.0   2.0      8.0   19.616000  10.240000
136      4096.0        280.0   8.0      1.0   25.599999  21.504000
137      4096.0        280.0   8.0      8.0   31.727999  19.455999
138      4096.0        512.0   1.0      1.0   15.360000  12.288000
139      4096.0        512.0   1.0      8.0   20.479999  12.192000
140      4096.0        512.0   2.0      1.0   17.408000  13.280000
141      4096.0        512.0   2.0      8.0   22.528000  12.288000
142      4096.0        512.0   8.0      1.0   28.672000  21.504000
143      4096.0        512.0   8.0      8.0   34.655999  19.455999
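As a rough aggregate, the speedup implied by a few hand-copied rows of the table above (the ratio is unit-free) can be summarized with a geometric mean:

```python
import math

# (main, PR) timing pairs copied from representative rows above.
rows = [
    (8.192, 6.144),          # num_tokens=1,    experts=16,  topk=1, ep=1
    (18.432001, 7.168),      # num_tokens=1,    experts=512, topk=1, ep=8
    (30.719999, 18.432001),  # num_tokens=4096, experts=64,  topk=8, ep=8
    (34.655999, 19.455999),  # num_tokens=4096, experts=512, topk=8, ep=8
]
speedups = [main / pr for main, pr in rows]
geo_mean = math.exp(sum(map(math.log, speedups)) / len(speedups))
print(f"geo-mean speedup over sampled rows: {geo_mean:.2f}x")
```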

Kernel Accuracy Test

Tested with

pytest -sv test_moe_align_block_size.py::test_moe_align_block_size

pytest -sv test_moe_align_block_size.py::test_moe_align_block_size_with_expert_map

All test cases passed.

E2E Accuracy Test

GSM8K (2-shot)

With Triton Kernel

vllm serve /data/pretrained_model/Qwen/Qwen1.5-MoE-A2.7B-Chat/ -tp 4 --served-model-name model --port 8002 --gpu-memory-utilization 0.75 --enable-expert-parallel


lm_eval \
    --model local-chat-completions \
    --tasks gsm8k \
    --num_fewshot 2 \
    --batch_size auto \
    --model_args "model=model,base_url=http://localhost:8002/v1/chat/completions,max_gen_toks=512,num_concurrent=128,max_length=4096" \
    --apply_chat_template


|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     2|exact_match|↑  |0.5095|±  |0.0138|
|     |       |strict-match    |     2|exact_match|↑  |0.0258|±  |0.0044|

With Marlin Kernel

vllm serve /data/pretrained_model/Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4/ -tp 4 --served-model-name model --port 8002 --gpu-memory-utilization 0.75  --enable-expert-parallel


lm_eval \
    --model local-chat-completions \
    --tasks gsm8k \
    --num_fewshot 2 \
    --batch_size auto \
    --model_args "model=model,base_url=http://localhost:8002/v1/chat/completions,max_gen_toks=512,num_concurrent=128,max_length=4096" \
    --apply_chat_template


|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     2|exact_match|↑  |0.4723|±  |0.0138|
|     |       |strict-match    |     2|exact_match|↑  |0.0190|±  |0.0038|

@mergify mergify bot added the performance Performance-related issues label Nov 28, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces several well-implemented optimizations to moe_align_block_size, resulting in significant performance gains as demonstrated by the benchmark results. The optimizations include using a tighter memory allocation for small batches, parallelizing data initialization within the CUDA kernels, and filtering invalid experts earlier in the expert parallelism path. Additionally, this PR includes a critical correctness fix for expert parallelism mode by ensuring an intermediate buffer is zero-initialized, preventing potential errors from uninitialized memory. The changes are clean, well-reasoned, and thoroughly tested. Overall, this is an excellent contribution that improves both performance and correctness.

Signed-off-by: Jinzhen Lin <[email protected]>
@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.

