Conversation

@kyle-256

Motivation

Fix an error in the sink attention backward pass when causal=true is set.

@kyle-256 kyle-256 requested review from a team and Copilot January 14, 2026 03:13
Contributor

Copilot AI left a comment

Pull request overview

This PR fixes a bug in the Triton kernel implementation of the sink attention backward pass when the causal flag is set to true. The mask block sizes were incorrectly divided by BLK_SLICE_FACTOR, which caused errors in the masked operations.

Changes:

  • Corrected mask block size calculations in the backward causal kernel by removing unnecessary division by BLK_SLICE_FACTOR


 descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

-MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+MASK_BLOCK_M1: tl.constexpr = BLOCK_M1

Copilot AI Jan 14, 2026

Corrected spelling of 'casual' to 'causal' in the PR title and description. The PR metadata contains 'casual=true', which should be 'causal=true'.

@kyle-256 changed the title from "[TRITON] fix sink_attn error when caual=true" to "[TRITON] fix sink_attn error when causal=true" on Jan 14, 2026
Contributor

@brunomazzottiamd brunomazzottiamd left a comment

Hello @kyle-256. Thanks for your PR! My review follows.

UT failures

Can you please provide more details about the UT failures you've been facing? Which Triton compiler are you using? I wasn't able to reproduce any UT failure with the latest Triton and latest AITER. You can check the details below.

Triton commit: triton-lang/triton@20251a3

AITER commit: da29487

Test results on MI300:

root@f799ed2bcfbf:/workspace/aiter# amd-smi static | grep -i market | sort | uniq
        MARKET_NAME: AMD Instinct MI300X
root@f799ed2bcfbf:/workspace/aiter# pytest op_tests/triton_tests/attention/test_mha.py -k with_sink
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/aiter
configfile: pyproject.toml
plugins: hypothesis-6.148.3
collected 9697 items / 9505 deselected / 192 selected

op_tests/triton_tests/attention/test_mha.py ............................................................ [ 31%]
........................................................................................................ [ 85%]
............................                                                                             [100%]

=============================================== warnings summary ===============================================
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
  /workspace/triton/python/triton/runtime/autotuner.py:101: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================= 192 passed, 9505 deselected, 4 warnings in 89.89s (0:01:29) ==========================

Test results on MI350:

root@ff4dcd1c1607:/workspace/aiter# amd-smi static | grep -i market | sort | uniq
        MARKET_NAME: AMD Instinct MI355X
root@ff4dcd1c1607:/workspace/aiter# pytest op_tests/triton_tests/attention/test_mha.py -k with_sink
============================================= test session starts ==============================================
platform linux -- Python 3.12.3, pytest-9.0.1, pluggy-1.6.0
rootdir: /workspace/aiter
configfile: pyproject.toml
plugins: hypothesis-6.148.3
collected 9697 items / 9505 deselected / 192 selected

op_tests/triton_tests/attention/test_mha.py ............................................................ [ 31%]
........................................................................................................ [ 85%]
............................                                                                             [100%]

=============================================== warnings summary ===============================================
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
../triton/python/triton/runtime/autotuner.py:101
  /workspace/triton/python/triton/runtime/autotuner.py:101: DeprecationWarning: warmup, rep, and use_cuda_graph parameters are deprecated. See https://github.com/triton-lang/triton/pull/4496 for details.
    warnings.warn(("warmup, rep, and use_cuda_graph parameters are deprecated. See "

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================== 192 passed, 9505 deselected, 4 warnings in 44.28s ===============================

Changing BLK_SLICE_FACTOR may affect kernel performance

Can you please do some profiling on both MI300 and MI350, so we can be sure that changing BLK_SLICE_FACTOR to 1 doesn't introduce performance regressions? You can use the op_tests/op_benchmarks/triton/bench_mha.py benchmark script to collect performance data.

What are your target shapes?

 descale_q, descale_k, descale_v, descale_do = 1.0, 1.0, 1.0, 1.0

-MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
+MASK_BLOCK_M1: tl.constexpr = BLOCK_M1
Contributor

I think the proper way to make this change is to set BLK_SLICE_FACTOR to 1 instead of removing BLK_SLICE_FACTOR. Please take a look at the aiter/ops/triton/configs/gfx942-MHA-DEFAULT.json and aiter/ops/triton/configs/gfx950-MHA-DEFAULT.json config files (the bkwd_onekernel BLK_SLICE_FACTOR entry).
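For illustration only, a sketch of what such a config entry might look like — the key names here are assumptions inferred from this comment, so check the actual gfx942/gfx950 JSON files for the real structure:

```json
{
  "bkwd_onekernel": {
    "BLK_SLICE_FACTOR": 1
  }
}
```

Keeping the parameter in the config (rather than deleting it from the kernel) lets each architecture keep its own tuned value.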

Contributor

BLK_SLICE_FACTOR is a performance tuning parameter; it's important to keep it.

@brunomazzottiamd
Contributor

@kyle-256, I'm double checking performance. I'll post my results in the PR as soon as I get them.

@brunomazzottiamd
Contributor

brunomazzottiamd commented Jan 15, 2026

Sharing my benchmarking numbers:

(all UTs passing with BLK_SLICE_FACTOR=2 and BLK_SLICE_FACTOR=1)

MI300

| TP | B | Layout | BLK_SLICE_FACTOR=2 time (ms) | BLK_SLICE_FACTOR=1 time (ms) | Speedup |
|----|----|--------|------------------------------|------------------------------|---------|
| 1 | 1 | bshd | 8.63 | 8.65 | 1.00 |
| 1 | 8 | thd | 66.07 | 66.76 | 0.99 |
| 1 | 9 | thd | 74.65 | 74.93 | 1.00 |
| 1 | 10 | thd | 81.90 | 83.05 | 0.99 |
| 1 | 11 | thd | 90.58 | 91.42 | 0.99 |
| 1 | 12 | thd | 98.60 | 99.71 | 0.99 |
| 1 | 13 | thd | 106.31 | 107.45 | 0.99 |
| 1 | 14 | thd | 115.05 | 116.61 | 0.99 |
| 1 | 15 | thd | 123.60 | 124.18 | 1.00 |
| 1 | 16 | thd | 130.91 | 132.06 | 0.99 |
| 8 | 1 | bshd | 5.23 | 5.34 | 0.98 |
| 8 | 8 | thd | 10.78 | 11.28 | 0.96 |
| 8 | 9 | thd | 11.16 | 11.74 | 0.95 |
| 8 | 10 | thd | 12.73 | 12.98 | 0.98 |
| 8 | 11 | thd | 14.54 | 14.99 | 0.97 |
| 8 | 12 | thd | 15.61 | 16.26 | 0.96 |
| 8 | 13 | thd | 15.71 | 16.51 | 0.95 |
| 8 | 14 | thd | 16.69 | 17.49 | 0.95 |
| 8 | 15 | thd | 18.05 | 18.37 | 0.98 |
| 8 | 16 | thd | 19.66 | 20.33 | 0.97 |
| Geomean | | | | | 0.98 |
  • BLK_SLICE_FACTOR=1 is 2% slower than BLK_SLICE_FACTOR=2.
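The Geomean rows in these tables can be reproduced with a few lines of Python; this is a generic sketch of the computation, not the benchmark harness itself:

```python
import math

def geomean(xs):
    """Geometric mean: exp of the arithmetic mean of the logs."""
    return math.exp(sum(math.log(x) for x in xs) / len(xs))

# Per-row speedup is time(BLK_SLICE_FACTOR=2) / time(BLK_SLICE_FACTOR=1);
# e.g. the first MI300 row: 8.63 ms vs 8.65 ms rounds to 1.00.
print(round(8.63 / 8.65, 2))

# The table's Geomean line is geomean() over all per-row speedups.
# Sanity check: reciprocal speedups cancel out to 1.0.
print(geomean([0.5, 2.0]))
```

Geometric mean (rather than arithmetic mean) is the standard way to summarize speedup ratios, since it treats a 2x slowdown and a 2x speedup as canceling.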

MI350

| TP | B | Layout | BLK_SLICE_FACTOR=2 time (ms) | BLK_SLICE_FACTOR=1 time (ms) | Speedup |
|----|----|--------|------------------------------|------------------------------|---------|
| 1 | 1 | bshd | 5.62 | 5.38 | 1.05 |
| 1 | 8 | thd | 44.93 | 42.04 | 1.07 |
| 1 | 9 | thd | 50.63 | 48.50 | 1.04 |
| 1 | 10 | thd | 56.11 | 53.56 | 1.05 |
| 1 | 11 | thd | 61.48 | 58.33 | 1.05 |
| 1 | 12 | thd | 67.44 | 62.92 | 1.07 |
| 1 | 13 | thd | 72.87 | 71.05 | 1.03 |
| 1 | 14 | thd | 78.16 | 74.48 | 1.05 |
| 1 | 15 | thd | 83.97 | 85.10 | 0.99 |
| 1 | 16 | thd | 89.15 | 89.57 | 1.00 |
| 8 | 1 | bshd | 4.48 | 3.35 | 1.34 |
| 8 | 8 | thd | 7.70 | 14.36 | 0.54 |
| 8 | 9 | thd | 10.05 | 9.02 | 1.11 |
| 8 | 10 | thd | 10.95 | 9.63 | 1.14 |
| 8 | 11 | thd | 10.75 | 9.71 | 1.11 |
| 8 | 12 | thd | 12.42 | 10.26 | 1.21 |
| 8 | 13 | thd | 12.88 | 11.86 | 1.09 |
| 8 | 14 | thd | 16.40 | 12.41 | 1.32 |
| 8 | 15 | thd | 16.99 | 13.47 | 1.26 |
| 8 | 16 | thd | 15.87 | 13.07 | 1.21 |
| Geomean | | | | | 1.07 |
  • BLK_SLICE_FACTOR=1 is 7% faster than BLK_SLICE_FACTOR=2.
  • However, the (TP=8, B=8, Layout=thd) case shows a big regression: its speedup is 0.54.
  • Excluding the (TP=8, B=8, Layout=thd) case, the speedup geomean is 1.11.

I used the following dirty script to get performance numbers:

#!/usr/bin/env bash

mode='bwd'
dtype='bf16'
sq=8192
sk="${sq}"
d=64
causal='yes'
common_args="--metric time -mode ${mode} --dtype ${dtype} -sq ${sq} -sk ${sk} -d ${d} -causal ${causal}"

echo "tp,b,layout,time_ms"

for tp in 1 8; do
    if [[ "${tp}" -eq 1 ]]; then
        hq=64
        hk=8
    elif [[ "${tp}" -eq 8 ]]; then
        hq=8
        hk=1
    fi

    for layout in 'bshd' 'thd'; do
        if [[ "${layout}" == 'bshd' ]]; then
            batch_sizes=(1)
        elif [[ "${layout}" == 'thd' ]]; then
            mapfile -t batch_sizes < <(seq 8 16)
        fi

        for b in "${batch_sizes[@]}"; do
            # Build the argument list fresh each iteration so flags from
            # earlier iterations don't accumulate.
            args="${common_args} -hq ${hq} -hk ${hk} --layout ${layout} -b ${b}"

            # shellcheck disable=SC2086
            time_ms=$(python op_tests/op_benchmarks/triton/bench_mha.py ${args} \
                2>&1 | tail -1 | tr --squeeze-repeats ' ' | cut --delimiter ' ' --fields 7)
            echo "${tp},${b},${layout},${time_ms}"
        done
    done
done
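To compare two CSV runs of the script (one per BLK_SLICE_FACTOR value), something along these lines works — a hedged sketch: the column names follow the script's `echo` header, and the file paths are placeholders you supply:

```python
import csv

def load_times(path):
    """Map (tp, b, layout) -> time_ms from the script's CSV output."""
    with open(path, newline="") as f:
        return {(row["tp"], row["b"], row["layout"]): float(row["time_ms"])
                for row in csv.DictReader(f)}

def speedups(baseline_csv, candidate_csv):
    """Speedup = baseline time / candidate time for each shared config."""
    base = load_times(baseline_csv)
    cand = load_times(candidate_csv)
    return {key: base[key] / cand[key] for key in base if key in cand}
```

Run the script once per BLK_SLICE_FACTOR setting, redirect each run to its own CSV, then pass both paths to speedups() to get the per-configuration ratios shown in the tables above.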

Feel free to do your own experiments and test your target shapes.
