[Operator] optimized flash_mla in triton #510
Conversation
src/flag_gems/fused/flash_mla.py
dv,
causal,
):
    logging.debug("GEMS FLASH MLA")
change to logger.debug
done
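For context, a minimal sketch of the suggested change, assuming the usual module-level logger pattern from the standard library `logging` module (the function signature here is a placeholder for illustration, not the real flash_mla interface):

```python
import logging

# Module-level logger; records carry the module name and can be filtered
# or silenced per module without reconfiguring the root logger.
logger = logging.getLogger(__name__)


def flash_mla(*args, dv, causal):  # placeholder signature for illustration
    logger.debug("GEMS FLASH MLA")
```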
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=120)
model_name = "deepseek-ai/DeepSeek-V3"
llm = LLM(
Just a heads-up: the full-version DeepSeek model typically requires multiple GPUs (e.g., 2×H100). Without some non-trivial changes to vLLM or DeepSeek’s config, this script may not run as expected.
lg
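Regarding the multi-GPU note above, a rough sketch of how the example might be adapted, assuming vLLM's `tensor_parallel_size` argument and a node with enough GPUs; everything outside the excerpted lines is illustrative, not the script's actual code:

```python
from vllm import LLM, SamplingParams

# Sampling parameters taken from the example above.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=120)

model_name = "deepseek-ai/DeepSeek-V3"

# The full DeepSeek-V3 checkpoint is too large for a single GPU, so the
# weights have to be sharded; tensor_parallel_size tells vLLM how many
# GPUs to split the model across.
llm = LLM(
    model=model_name,
    tensor_parallel_size=8,   # set to the number of GPUs actually available
    trust_remote_code=True,   # DeepSeek models ship custom modeling code
)

outputs = llm.generate(["Hello, my name is"], sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```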
PR Category
Operator
Type of Change
Performance Optimization
Description
Reimplemented flash_mla and achieved better performance than the baseline.
Performance on NVIDIA A100, compared to torch.

On average, it achieves 78% of the performance of flash_infer (backend=fa2).
Performance on NVIDIA H800, compared to all implementations.
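As a point of reference, one way a relative-performance figure like the 78% above could be measured, sketched with `triton.testing.do_bench`; `gems_flash_mla` and `flashinfer_decode` are hypothetical stand-ins for the two implementations, not actual API names:

```python
import triton.testing


def relative_performance(candidate, reference):
    """Return the candidate's throughput as a fraction of the reference's.

    Both arguments are zero-argument callables that launch one kernel
    invocation; do_bench reports latency in milliseconds, so a lower
    candidate time yields a ratio above 1.0.
    """
    t_candidate = triton.testing.do_bench(candidate)
    t_reference = triton.testing.do_bench(reference)
    return t_reference / t_candidate


# Hypothetical usage; a result of 0.78 would correspond to the 78% figure:
# ratio = relative_performance(lambda: gems_flash_mla(*args), lambda: flashinfer_decode(*args))
```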
Issue
Progress
Performance