
Conversation

micmelesse (Contributor)

Motivation

Add support for fp8 and paged attention in Flash Attention

Technical Details

Modify the existing code so that it conforms to the Flash Attention v3 API. The user provides fp8 values for q, k, and v along with their descale factors.
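
For context, per-tensor fp8 quantization with a descale factor typically looks like the sketch below. The helper name, the e4m3fnuz dtype, and the scalar (per-tensor) descale are illustrative assumptions, not the utilities this PR actually adds; see aiter/ops/triton/utils/mha_kernel_utils.py for the real implementation.

import torch

def quantize_fp8_per_tensor(x: torch.Tensor, fp8_dtype=torch.float8_e4m3fnuz):
    # Hypothetical helper: map the tensor into the fp8 range and return both the
    # quantized tensor and the descale factor the kernel uses to undo the scaling.
    fp8_max = torch.finfo(fp8_dtype).max
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp8_max / amax      # multiply by this to fit the fp8 range
    descale = amax / fp8_max    # the kernel multiplies by this after loading
    x_fp8 = (x.float() * scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_fp8, descale.reshape(1)

q_fp8, q_descale = quantize_fp8_per_tensor(q)
k_fp8, k_descale = quantize_fp8_per_tensor(k)
v_fp8, v_descale = quantize_fp8_per_tensor(v)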

Test Plan

Update the MHA tests and benchmark code.

Test Result

Submission Checklist

micmelesse (Contributor, Author) commented Sep 24, 2025

You can see examples of how to use the fp8 interface with both regular and paged attention in tests such as op_tests/triton_tests/test_mha.py.

For regular attention with fp8, you will see code that looks like this:

import torch

from aiter.ops.triton.mha_v3 import (
    flash_attn_func as flash_attn_func_v3,
)

# forward
triton_out = flash_attn_func_v3(
    q_fp8,
    k_fp8,
    v_fp8,
    softmax_scale=None,
    causal=CAUSAL,
    q_descale=q_descale,
    k_descale=k_descale,
    v_descale=v_descale,
)

# backward: `do` is the upstream gradient of the attention output
triton_dq, triton_dk, triton_dv = torch.autograd.grad(
    triton_out, (q_fp8, k_fp8, v_fp8), do.clone()
)
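
A typical way to sanity-check fp8 accuracy is to compare against a higher-precision reference. Below is a minimal sketch of such a comparison, assuming scalar per-tensor descales and a (batch, seqlen, nheads, head_dim) layout, and using PyTorch's scaled_dot_product_attention as the reference; the actual reference and tolerances used in test_mha.py may differ.

import torch
import torch.nn.functional as F

# Dequantize with the descale factors, then run a float32 reference.
# SDPA expects (batch, nheads, seqlen, head_dim), so transpose in and out.
q_ref = (q_fp8.float() * q_descale).transpose(1, 2)
k_ref = (k_fp8.float() * k_descale).transpose(1, 2)
v_ref = (v_fp8.float() * v_descale).transpose(1, 2)

ref_out = F.scaled_dot_product_attention(q_ref, k_ref, v_ref, is_causal=CAUSAL)
ref_out = ref_out.transpose(1, 2)

# fp8 needs loose tolerances; the values here are illustrative only.
torch.testing.assert_close(triton_out.detach().float(), ref_out, atol=2.5e-2, rtol=2.5e-2)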

Here is the loss curve from a small model trained on wikitext to test convergence:

[combined_loss: combined training-loss plot]

For paged attention with fp8, which is available through the inference API flash_attn_with_kvcache, you will see code that looks like this:

from aiter.ops.triton.mha_v3 import (
    flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
)

# forward
out_kernel = flash_attn_with_kvcache_v3(
    q_fp8,
    k_cache_fp8,
    v_cache_fp8,
    cache_seqlens=cache_seqlens,
    causal=causal,
    q_descale=q_descale,
    k_descale=k_descale,
    v_descale=v_descale,
    page_table=page_table,
)
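
If you need to construct the paged KV-cache inputs yourself, here is a minimal sketch under assumed shapes: k_cache/v_cache laid out as (num_pages, page_size, nheads_k, head_dim) and page_table holding int32 page indices per batch element (the FA3-style convention). The sizes and dtype below are illustrative; check test_mha.py for the exact shapes this interface expects.

import torch

# Illustrative sizes only.
batch, max_seqlen_k, nheads_k, head_dim = 2, 512, 8, 128
page_size = 128
pages_per_seq = max_seqlen_k // page_size
num_pages = batch * pages_per_seq

# Paged KV cache: each page stores page_size tokens for every KV head.
k_cache_fp8 = torch.empty(
    num_pages, page_size, nheads_k, head_dim,
    dtype=torch.float8_e4m3fnuz, device="cuda",
)
v_cache_fp8 = torch.empty_like(k_cache_fp8)

# page_table[b, i] is the index of the i-th page belonging to sequence b.
page_table = torch.arange(num_pages, dtype=torch.int32, device="cuda").reshape(
    batch, pages_per_seq
)

# Number of valid tokens currently cached for each sequence.
cache_seqlens = torch.full((batch,), 300, dtype=torch.int32, device="cuda")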

micmelesse mentioned this pull request Sep 24, 2025
micmelesse marked this pull request as ready for review September 24, 2025 03:10

Copilot AI left a comment

Pull Request Overview

This PR adds support for Flash Attention v3 API with FP8 quantization and paged attention capabilities. The changes modify existing code to conform to the Flash Attention v3 interface while maintaining backward compatibility with v2.

  • Updates MHA tests to use v3 API with FP8 support and new quantization utilities
  • Introduces new v3 interfaces with proper FP8 descaling and KV cache functionality
  • Adds comprehensive FP8 quantization utilities for different tensor layouts

Reviewed Changes

Copilot reviewed 13 out of 15 changed files in this pull request and generated 3 comments.

Summary per file:
  • op_tests/triton_tests/test_mha.py: Updates test suite to support both v2 and v3 APIs with FP8 quantization
  • op_tests/op_benchmarks/triton/bench_mha.py: Updates benchmark code to use the v3 FP8 quantization API
  • aiter/ops/triton/utils/mha_kernel_utils.py: Adds FP8 quantization and dequantization utilities
  • aiter/ops/triton/mha_v3.py: Implements the Flash Attention v3 interface with FP8 support
  • aiter/ops/triton/mha.py: Removes deprecated FP8 functions and adds v2 KV cache support
  • Multiple Triton kernel files: Adds comprehensive v2/v3 interfaces and an FP8 prefill implementation
Comments suppressed due to low confidence (1)

aiter/ops/triton/mha_v3.py:1
# SPDX-License-Identifier: MIT

  • The union syntax int | None should be replaced with Optional[int] for consistent type annotation style and Python version compatibility.
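
For reference, the two annotation styles being compared look like this (the parameter name is just an illustration):

from typing import Optional

# PEP 604 union syntax; needs Python 3.10+ when the annotation is evaluated at runtime
def fn_new(window_size: int | None = None) -> None: ...

# Equivalent Optional[...] spelling that also works on older Python versions
def fn_old(window_size: Optional[int] = None) -> None: ...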


FA V3(fp8) and paged Attention

  • FP8 Prefill work compressed
  • Fa V3 api
  • Compress fp8 work so far
  • pull cast out of torch function
  • e2e fp8 stub
  • emulate fa v3
  • ignore
  • remove example
  • clean up forward
  • save
  • fp8 backward
  • ignore train artifacts
  • just use return_attn_probs
  • match fa behavior
  • save fa ref
  • add fa_ref
  • fix dropout bug
  • add link
  • optional fp8 p descale
  • rename to v3
  • fa v3
  • clean up
  • match backward
  • min diff
  • update varlen api
  • clean up FP8_P_DESCALE
  • update bench and test
  • lint
  • fix mha varlen bug
  • remove .gitignore
  • save
  • lint
  • remove skip
  • bring back skips
  • add fa module
  • update v2 interface
  • create mha_v3
  • add old v3 path
  • update fa module
  • tests passing
  • sync bwd changes
  • lint fa module
  • add kvcache api and test
  • fix lint
  • fp8 works
  • test fp8 only
  • add paged tests
  • add flash_attn_with_kvcache to v2
  • test varlen
  • move to _triton_kernels
  • test_mha_backward working with v3
  • upgrade to cleaned-up module
  • get test_mha_backward_varlen working
  • clean up
  • fix lint bug
  • move casting functions to utils
  • fix lint
  • Update aiter/ops/triton/utils/mha_kernel_utils.py (Co-authored-by: Copilot <[email protected]>)
  • Update aiter/ops/triton/utils/mha_kernel_utils.py (Co-authored-by: Copilot <[email protected]>)
  • use Optional
  • update from main_perf
  • lint