[TRITON] fix sink_attn error when causal=true #1837
Conversation
Pull request overview
This PR fixes a bug in the Triton kernel implementation of the sink attention backward pass when the causal flag is set to true. The mask block sizes were being divided by BLK_SLICE_FACTOR, which caused errors in the masked operations.
Changes:
- Corrected the mask block size calculation in the backward causal kernel by removing the division by BLK_SLICE_FACTOR
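For readers unfamiliar with the slicing scheme, here is a minimal sketch of the arithmetic the change touches (plain Python rather than the actual Triton kernel, with illustrative values; the backward loop that consumes these constants is not shown):

```python
# Standalone sketch; BLOCK_M1 and BLK_SLICE_FACTOR values are examples only.
BLOCK_M1 = 64          # rows handled per backward iteration (illustrative)
BLK_SLICE_FACTOR = 2   # tuning knob from the MHA configs (illustrative)

# Before this PR: the masked (diagonal) phase runs in BLK_SLICE_FACTOR
# slices of MASK_BLOCK_M1 rows each.
MASK_BLOCK_M1 = BLOCK_M1 // BLK_SLICE_FACTOR   # 32
num_masked_steps = BLOCK_M1 // MASK_BLOCK_M1   # 2

# After this PR: MASK_BLOCK_M1 equals BLOCK_M1, i.e. one full-width
# masked step, which behaves like forcing BLK_SLICE_FACTOR = 1.
MASK_BLOCK_M1 = BLOCK_M1
num_masked_steps = BLOCK_M1 // MASK_BLOCK_M1   # 1
```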
```diff
  descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0
  ...
- MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1
```
Copilot AI (Jan 14, 2026):
Corrected spelling of 'casual' to 'causal' in PR title and description. The PR metadata contains 'casual=true', which should be 'causal=true'.
Hello @kyle-256. Thanks for your PR! My review follows.
UT failures
Can you please provide more details about the UT failures you've been facing? Which Triton compiler are you using? I wasn't able to reproduce any UT failure with the latest Triton and the latest AITER. You can check the details below.
Triton commit: triton-lang/triton@20251a3
AITER commit: da29487
Test results on MI300:
root@f799ed2bcfbf:/workspace/aiter# amd-smi static | grep -i market | sort | uniq
MARKET_NAME: AMD Instinct MI300X
root@f799ed2bcfbf:/workspace/aiter# pytest op_tests/triton_tests/attention/test_mha.py -k with_sink
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/aiter
configfile: pyproject.toml
plugins: hypothesis-6.148.3
collected 9697 items / 9505 deselected / 192 selected
op_tests/triton_tests/attention/test_mha.py ............................................................ [ 31%]
........................................................................................................ [ 85%]
............................ [100%]
=============================================== warnings summary ===============================================
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
/workspace/triton/python/triton/runtime/autotuner.py:101: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================= 192 passed, 9505 deselected, 4 warnings in 89.89s (0:01:29) ==========================
Test results on MI350:
root@ff4dcd1c1607:/workspace/aiter# amd-smi static | grep -i market | sort | uniq
MARKET_NAME: AMD Instinct MI355X
root@ff4dcd1c1607:/workspace/aiter# pytest op_tests/triton_tests/attention/test_mha.py -k with_sink
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/aiter
configfile: pyproject.toml
plugins: hypothesis-6.148.3
collected 9697 items / 9505 deselected / 192 selected
op_tests/triton_tests/attention/test_mha.py ............................................................ [ 31%]
........................................................................................................ [ 85%]
............................ [100%]
=============================================== warnings summary ===============================================
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
/workspace/triton/python/triton/runtime/autotuner.py:101: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================== 192 passed, 9505 deselected, 4 warnings in 44.28s ===============================
Changing BLK_SLICE_FACTOR may affect kernel performance
Can you please do some profiling on both MI300 and MI350, so we can be sure that changing BLK_SLICE_FACTOR to 1 doesn't introduce performance regressions? You can use the op_tests/op_benchmarks/triton/bench_mha.py benchmark script to get performance data.
What are your target shapes?
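For example, one backward shape can be timed directly (the flags below are lifted from the sweep script shared later in this thread; substitute your own target shapes):

```bash
python op_tests/op_benchmarks/triton/bench_mha.py \
    --metric time -mode bwd --dtype bf16 \
    -sq 8192 -sk 8192 -d 64 -causal yes \
    -hq 64 -hk 8 --layout bshd -b 1
```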
```diff
  descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0
  ...
- MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1
```
I think the proper way of doing this change is by setting BLK_SLICE_FACTOR to 1 instead of removing BLK_SLICE_FACTOR. Please take a look at aiter/ops/triton/configs/gfx942-MHA-DEFAULT.json and aiter/ops/triton/configs/gfx950-MHA-DEFAULT.json config files (bkwd_onekernel → onekernel → BLK_SLICE_FACTOR).
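As a sketch of where that knob lives, assuming only the key path named above (the rest of the files' layout is not shown or known here):

```python
import json

# The key path (bkwd_onekernel -> onekernel -> BLK_SLICE_FACTOR) is taken
# from the comment above; any surrounding keys in the file are assumptions.
path = "aiter/ops/triton/configs/gfx942-MHA-DEFAULT.json"
with open(path) as f:
    cfg = json.load(f)

print(cfg["bkwd_onekernel"]["onekernel"]["BLK_SLICE_FACTOR"])

# The suggested fix would set this value to 1 in both the gfx942 and
# gfx950 config files, rather than deleting the parameter from the kernel.
cfg["bkwd_onekernel"]["onekernel"]["BLK_SLICE_FACTOR"] = 1
```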
BLK_SLICE_FACTOR is a performance tuning parameter; it's important to keep it.
@kyle-256, I'm double-checking performance. I'll post my results in the PR as soon as I get them.
Sharing my benchmarking numbers (all UTs passing):
[MI300 and MI350 benchmark tables not captured in this extract]
I used the following dirty script to get performance numbers:

```bash
#!/usr/bin/env bash
# Quick-and-dirty sweep over bench_mha.py shapes; prints CSV to stdout.
mode='bwd'
dtype='bf16'
sq=8192
sk="${sq}"
d=64
causal='yes'
common_args="--metric time -mode ${mode} --dtype ${dtype} -sq ${sq} -sk ${sk} -d ${d} -causal ${causal}"

echo "tp,b,layout,time_ms"
for tp in 1 8; do
  # hq/hk head counts for each tp value (tp presumably = tensor-parallel degree).
  if [[ "${tp}" -eq 1 ]]; then
    hq=64
    hk=8
  elif [[ "${tp}" -eq 8 ]]; then
    hq=8
    hk=1
  fi
  tp_args="${common_args} -hq ${hq} -hk ${hk}"
  for layout in 'bshd' 'thd'; do
    layout_args="${tp_args} --layout ${layout}"
    if [[ "${layout}" == 'bshd' ]]; then
      batch_sizes=(1)
    elif [[ "${layout}" == 'thd' ]]; then
      mapfile -t batch_sizes < <(seq 8 16)
    fi
    for b in "${batch_sizes[@]}"; do
      args="${layout_args} -b ${b}"
      # Take the time column from the last line of the benchmark output.
      # shellcheck disable=SC2086
      time_ms=$(python op_tests/op_benchmarks/triton/bench_mha.py ${args} \
        2>&1 | tail -1 | tr --squeeze-repeats ' ' | cut --delimiter ' ' --fields 7)
      echo "${tp},${b},${layout},${time_ms}"
    done
  done
done
```

Feel free to do your own experiments and test your target shapes.
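If saved as, say, bench_sweep.sh (the filename is illustrative), it can be run from the aiter repo root and captured as CSV:

```bash
bash bench_sweep.sh | tee bench_results.csv
```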
Motivation
Fix an error in the sink attention backward pass when causal=true is set.