[Operator] optimized flash_mla in triton #510

Merged

merged 10 commits into from Jul 9, 2025
Conversation

@StrongSpoon (Collaborator) commented on Mar 25, 2025

PR Category

Operator

Type of Change

Performance Optimization

Description

Reimplemented flash_mla in Triton, achieving better performance than the baseline.

Performance on NVIDIA A100, compared to torch.
[Screenshot: A100 benchmark results, 2025-03-28 17:54:25]

On average, it achieves 78% of flash_infer's performance (backend=fa2).

Performance on NVIDIA H800, compared to all implementations.

Bandwidth (GB/s) per sequence length:

seqlen              1151   2175   4223   8319  16511  32895
torch                  2      2      2      2      2      2
flash_mla           1726   2012   2186   2259   2240   2275
flash_infer         1579   1638   1669   1662   1671   1660
flash_mla_triton      88     94    125    150    173    190
flag_mla_triton     1107   1156   1181   1162   1156   1169
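For context, a minimal sketch of how bandwidth numbers like the table above are typically produced with triton.testing.do_bench. This is not the PR's benchmark script: mla_decode is a hypothetical stand-in for the operator under test, the byte count assumes KV-cache reads dominate memory traffic, and batch=128 is illustrative.

    import torch
    import triton

    def bench_bandwidth(mla_decode, batch, seqlen, num_heads=16, head_dim=576,
                        dtype=torch.float16):
        # Random inputs sized like a single decode step over a long KV cache.
        q = torch.randn(batch, num_heads, head_dim, device="cuda", dtype=dtype)
        kv = torch.randn(batch, seqlen, head_dim, device="cuda", dtype=dtype)
        # Median kernel time in milliseconds.
        ms = triton.testing.do_bench(lambda: mla_decode(q, kv))
        bytes_moved = kv.numel() * kv.element_size()  # KV-cache reads dominate
        return bytes_moved / (ms * 1e-3) / 1e9  # GB/s

    # for seqlen in (1151, 2175, 4223, 8319, 16511, 32895):
    #     print(seqlen, bench_bandwidth(my_mla_kernel, batch=128, seqlen=seqlen))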

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

    dv,
    causal,
):
    logging.debug("GEMS FLASH MLA")
Collaborator:

change to logger.debug

Collaborator (Author):

done
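A minimal sketch of the requested change, assuming the usual module-level-logger convention; the function signature is hypothetical, since the review context above only shows its tail end.

    import logging

    logger = logging.getLogger(__name__)  # replaces direct logging.debug calls

    def flash_mla(q, kv, dv, causal):  # hypothetical signature for illustration
        logger.debug("GEMS FLASH MLA")
        ...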

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=120)
model_name = "deepseek-ai/DeepSeek-V3"
llm = LLM(
Collaborator:
Just a heads-up: the full-version DeepSeek model typically requires multiple GPUs (e.g., 2×H100). Without some non-trivial changes to vLLM or DeepSeek’s config, this script may not run as expected.
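A hedged sketch of how the example script might be adapted for a multi-GPU setup, per the comment above. tensor_parallel_size=2 assumes two GPUs are available and is illustrative, not a tested configuration for the full DeepSeek-V3 model.

    from vllm import LLM, SamplingParams

    sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=120)
    llm = LLM(
        model="deepseek-ai/DeepSeek-V3",
        tensor_parallel_size=2,   # shard the model across 2 GPUs (assumption)
        trust_remote_code=True,
    )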

kiddyjinjin previously approved these changes Jun 23, 2025
kiddyjinjin previously approved these changes Jul 2, 2025
@kiddyjinjin (Collaborator) left a comment:
lg

meinie0826 previously approved these changes Jul 8, 2025
@StrongSpoon StrongSpoon dismissed stale reviews from meinie0826 and kiddyjinjin via 04f4787 July 8, 2025 09:10
@meinie0826 meinie0826 merged commit 73f9236 into master Jul 9, 2025
8 of 14 checks passed
@meinie0826 meinie0826 deleted the flashmla branch July 9, 2025 02:37