[Triton] FA v3 API #1065
Conversation
Force-pushed from c834e99 to 9e0b3f4.
Pull Request Overview
This PR adds support for the Flash Attention v3 API with FP8 quantization and paged attention capabilities. The changes modify existing code to conform to the Flash Attention v3 interface while maintaining backward compatibility with v2.
- Updates the MHA tests to use the v3 API with FP8 support and new quantization utilities
- Introduces new v3 interfaces with proper FP8 descaling and KV cache functionality (a call sketch follows this list)
- Adds comprehensive FP8 quantization utilities for different tensor layouts
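For orientation, the sketch below shows what a v3-style FP8 call with per-tensor descale factors can look like: quantize q, k, and v to FP8, keep the descale factors, and hand both to the attention entry point. The helper name `quantize_fp8`, the entry point `flash_attn_func`, and the `*_descale` keyword names follow the upstream Flash Attention v3 convention and are assumptions here, not verbatim signatures from this PR.

```python
# Sketch only: quantize q/k/v to FP8 and call a v3-style API with descale factors.
import torch


def quantize_fp8(x: torch.Tensor, dtype=torch.float8_e4m3fnuz):
    """Per-tensor FP8 quantization: return the FP8 tensor and its float32 descale factor."""
    # torch.float8_e4m3fnuz is the AMD-style FP8 dtype; use torch.float8_e4m3fn elsewhere.
    fp8_max = torch.finfo(dtype).max
    descale = (x.abs().amax().float() / fp8_max).clamp(min=1e-12)  # dequant = x_fp8 * descale
    x_fp8 = (x.float() / descale).clamp(-fp8_max, fp8_max).to(dtype)
    return x_fp8, descale


q, k, v = (torch.randn(2, 1024, 8, 128, device="cuda", dtype=torch.float16) for _ in range(3))
(q_fp8, q_descale), (k_fp8, k_descale), (v_fp8, v_descale) = map(quantize_fp8, (q, k, v))

# Hypothetical v3-style call; the real entry point lives in aiter.ops.triton.mha_v3 and its
# exact signature may differ from this sketch:
# from aiter.ops.triton.mha_v3 import flash_attn_func
# out = flash_attn_func(q_fp8, k_fp8, v_fp8,
#                       q_descale=q_descale, k_descale=k_descale, v_descale=v_descale,
#                       causal=True)
```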
Reviewed Changes
Copilot reviewed 13 out of 15 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| op_tests/triton_tests/test_mha.py | Updates the test suite to support both the v2 and v3 APIs with FP8 quantization |
| op_tests/op_benchmarks/triton/bench_mha.py | Updates the benchmark code to use the v3 FP8 quantization API |
| aiter/ops/triton/utils/mha_kernel_utils.py | Adds FP8 quantization and dequantization utilities (sketched below the table) |
| aiter/ops/triton/mha_v3.py | Implements the Flash Attention v3 interface with FP8 support |
| aiter/ops/triton/mha.py | Removes deprecated FP8 functions and adds v2 KV cache support |
| Multiple Triton kernel files | Adds comprehensive v2/v3 interfaces and an FP8 prefill implementation |
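As a rough picture of what FP8 quantization/dequantization helpers of this kind do, here is a hypothetical per-head quantize/dequantize pair with a roundtrip check. The real utilities live in `aiter/ops/triton/utils/mha_kernel_utils.py` and may use different names, layouts, and scaling granularity.

```python
# Hypothetical per-head FP8 quantize/dequantize pair; names and scaling granularity are assumptions.
import torch

FP8_DTYPE = torch.float8_e4m3fn  # illustrative default; AMD GPUs typically use torch.float8_e4m3fnuz


def quantize_fp8_per_head(x: torch.Tensor):
    """Quantize a [batch, seqlen, nheads, headdim] tensor with one scale per (batch, head)."""
    fp8_max = torch.finfo(FP8_DTYPE).max
    amax = x.abs().amax(dim=(1, 3), keepdim=True).float()  # shape [batch, 1, nheads, 1]
    descale = (amax / fp8_max).clamp(min=1e-12)            # dequantization factor per (batch, head)
    x_fp8 = (x.float() / descale).to(FP8_DTYPE)
    return x_fp8, descale


def dequantize_fp8(x_fp8: torch.Tensor, descale: torch.Tensor, out_dtype=torch.float16):
    """Invert quantize_fp8_per_head by rescaling back to a higher-precision dtype."""
    return (x_fp8.float() * descale).to(out_dtype)


x = torch.randn(2, 512, 8, 128, dtype=torch.float16)
x_fp8, descale = quantize_fp8_per_head(x)
x_rt = dequantize_fp8(x_fp8, descale)
print((x - x_rt).abs().max())  # roundtrip error should stay small relative to the FP8 step size
```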
Comments suppressed due to low confidence (1)
aiter/ops/triton/mha_v3.py:1
- The union syntax `int | None` should be replaced with `Optional[int]` for consistent type annotation style and Python version compatibility.
# SPDX-License-Identifier: MIT
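For reference, the two annotation spellings the suppressed comment contrasts are equivalent at runtime; the union form simply needs Python 3.10+ (or `from __future__ import annotations`) to parse. The function and parameter names below are illustrative only.

```python
from typing import Optional


def window_left(size: int | None = None) -> None:  # PEP 604 union syntax, Python 3.10+
    ...


def window_right(size: Optional[int] = None) -> None:  # equivalent, works on older Python versions
    ...
```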
Force-pushed from 0fd3a58 to f3c41c0. Squashed commit messages:
- FA V3 (fp8) and paged attention
- FP8 prefill work
- compressed FA v3 API
- compress fp8 work so far
- pull cast out of torch function
- e2e fp8 stub
- emulate fa v3
- ignore
- remove example
- clean up forward
- save fp8 backward
- ignore train artifacts
- just use return_attn_probs
- match fa behavior
- save fa ref
- add fa_ref
- fix dropout bug
- add link
- optional fp8 p descale
- rename to v3
- fa v3 clean up
- match backward
- min diff
- update varlen api
- clean up FP8_P_DESCALE
- update bench and test
- lint
- fix mha varlen bug
- remove .gitignore
- save
- lint
- remove skip
- bring back skips
- add fa module
- update v2 interface
- create mha_v3
- add old v3 path
- update fa module
- tests passing
- sync bwd changes
- lint fa module
- add kvcache api and test
- fix lint
- fp8 works
- test fp8 only
- add paged tests
- add flash_attn_with_kvcache to v2
- test varlen
- move to _triton_kernels
- test_mha_backward working with v3
- upgrade to cleaned-up module
- get test_mha_backward_varlen working
- clean up
- fix lint bug
- move casting functions to utils
- fix lint
- Update aiter/ops/triton/utils/mha_kernel_utils.py (Co-authored-by: Copilot <[email protected]>)
- Update aiter/ops/triton/utils/mha_kernel_utils.py (Co-authored-by: Copilot <[email protected]>)
- use Optional
- update from main_perf
- lint
Motivation
Add support for FP8 and paged attention in Flash Attention.
Technical Details
Modify existing code so that it conforms to the Flash Attention v3 API. A user provides FP8 values for q, k, and v along with their descale values.
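Since the PR also covers paged attention and KV-cache support, here is a rough sketch of that call pattern. The block-table layout and the `flash_attn_with_kvcache` name follow the upstream Flash Attention convention; the exact aiter entry point and argument names may differ.

```python
# Hypothetical paged KV-cache decode step; argument names follow the upstream
# flash_attn_with_kvcache convention and are assumptions, not this PR's exact interface.
import torch

batch, nheads, headdim = 2, 8, 128
page_size, num_pages = 256, 64

q = torch.randn(batch, 1, nheads, headdim, device="cuda", dtype=torch.float16)  # one new query token per sequence
k_cache = torch.zeros(num_pages, page_size, nheads, headdim, device="cuda", dtype=torch.float16)
v_cache = torch.zeros_like(k_cache)
block_table = torch.arange(num_pages, device="cuda", dtype=torch.int32).reshape(batch, -1)  # pages owned by each sequence
cache_seqlens = torch.tensor([300, 700], device="cuda", dtype=torch.int32)  # valid tokens currently in each cache

# from aiter.ops.triton.mha_v3 import flash_attn_with_kvcache   # assumed entry point
# out = flash_attn_with_kvcache(q, k_cache, v_cache,
#                               block_table=block_table, cache_seqlens=cache_seqlens,
#                               causal=True)
```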
Test Plan
Update the MHA tests and benchmark code.
Test Result
Submission Checklist