feat(models): Triton Attention backend #716
base: main
Conversation
fused-attention-batch1-head8-d64-fwd-window=0:
N_CTX Triton [FP16] Flash-2
0 1024.0 53.951905 30.154757
1 2048.0 251.862878 225.601753
2 4096.0 1003.585900 883.967242
3 8192.0 3087.151752 2919.118612
4 16384.0 14424.333821 12164.023961
fused-attention-batch1-head8-d64-fwd-window=256:
N_CTX Triton [FP16] Flash-2
0 1024.0 66.118373 49.047213
1 2048.0 212.385098 224.781601
2 4096.0 627.400725 754.676149
3 8192.0 2403.518850 2072.301033
4 16384.0 5407.098626 4389.974473
fused-attention-batch1-head8-d64-bwd-window=0:
N_CTX Triton [FP16] Flash-2
0 1024.0 29.038600 28.145139
1 2048.0 103.598599 107.078404
2 4096.0 221.267879 466.277452
3 8192.0 238.154238 1926.255278
4 16384.0 244.473300 7453.972862
fused-attention-batch1-head8-d64-bwd-window=256:
N_CTX Triton [FP16] Flash-2
0 1024.0 31.743517 30.864637
1 2048.0 90.926897 119.329910
2 4096.0 219.786079 486.217597
3 8192.0 236.157174 1804.149605
4 16384.0 243.949151 4204.442805
fused-attention-batch1-head8-d128-fwd-window=0:
N_CTX Triton [FP16] Flash-2
0 1024.0 113.015749 97.006563
1 2048.0 355.368512 378.363736
2 4096.0 1539.535832 1366.327294
3 8192.0 7007.632047 6028.029964
4 16384.0 19762.887419 17225.973294
fused-attention-batch1-head8-d128-fwd-window=256:
N_CTX Triton [FP16] Flash-2
0 1024.0 103.508837 89.663668
1 2048.0 525.090419 464.968188
2 4096.0 1235.160998 1332.826567
3 8192.0 3203.617410 2936.990470
4 16384.0 6880.275897 6123.385195
fused-attention-batch1-head8-d128-bwd-window=0:
N_CTX Triton [FP16] Flash-2
0 1024.0 31.435884 62.969277
1 2048.0 58.816000 240.568425
2 4096.0 81.556234 987.361971
3 8192.0 86.844913 3692.221169
4 16384.0 88.609923 8978.312907
fused-attention-batch1-head8-d128-bwd-window=256:
N_CTX Triton [FP16] Flash-2
0 1024.0 31.168216 62.057629
1 2048.0 57.411547 240.228307
2 4096.0 73.313512 920.541072
3 8192.0 75.291760 2300.896123
4 16384.0 75.907589 4835.565817
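(The table names above follow the Triton fused-attention tutorial's benchmark naming. The script that produced these numbers is not shown in this thread; timings like these are typically gathered with `triton.testing.do_bench`, as in this rough sketch — the callable names are placeholders.)

```python
import triton


def bench(fn, q, k, v):
    # Mean wall-clock time in milliseconds over warmed-up repetitions.
    return triton.testing.do_bench(lambda: fn(q, k, v))


# Hypothetical comparison of two attention callables on the same inputs:
# t_triton = bench(triton_attention, q, k, v)
# t_flash  = bench(flash_attn_func, q, k, v)
```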
for more information, see https://pre-commit.ci
…ore into feature/triton-flash-attn
…uld be off if window-size was not a factor of BLOCK_N
Something strange is afoot with the loss for longer runs.
I compared an aifs-single setup over 4 GH200s for 4000 iterations here. Loss looks good, and the triton backend completed faster.
…ring inference runtime
…ad dims. Complex because there are more block sizes and dependencies between them due to sharing a grid. Won't run, but pytest, which doesn't autotune, passes.
japols
left a comment
Nice work!
Some comments on dtypes and a few minor changes.
Co-authored-by: Jan Polster <[email protected]>
Co-authored-by: Jan Polster <[email protected]>
…ore into feature/triton-flash-attn
for more information, see https://pre-commit.ci
…ore into feature/triton-flash-attn
…rid via autotuning
for more information, see https://pre-commit.ci
…ore into feature/triton-flash-attn
ssmmnn11
left a comment
nice!
```python
pre_hook=_host_descriptor_pre_hook,
)
for BM in [32, 64, 128]
for BN in [32, 64, 128]
```
no 16?
Good idea. It passes the pytests and gives a speedup from 8.55 ms/iter to 7.06 ms/iter at 2048 channels... maybe I should try 8 :D
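For reference, a minimal sketch of what widening the autotune search space with the smaller block sizes could look like; the block-size kwarg names, the `num_stages`/`num_warps` values, and the stubbed pre-hook are placeholders, not the PR's actual configuration:

```python
import triton


def _host_descriptor_pre_hook(nargs):
    # Stub standing in for the pre-hook referenced in the hunk above.
    pass


# Hypothetical extension of the autotune search space with block sizes of 16.
configs = [
    triton.Config(
        {"BLOCK_M": BM, "BLOCK_N": BN},
        num_stages=stages,
        num_warps=warps,
        pre_hook=_host_descriptor_pre_hook,
    )
    for BM in [16, 32, 64, 128]
    for BN in [16, 32, 64, 128]
    for stages in [3, 4]
    for warps in [4, 8]
]
```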
```python
# Meaning there is at least (BATCH_SIZE * NUM_HEADS) SMs
# Depending on BLOCK_FIXED, the context window might also be split across SMs
# BLOCK_FIXED is a hyperparameter which triton sets at runtime by running small performance tests
def grid(META):
```
Is N_CTX always divisible by BLOCK_FIXED? If not, do we need a mask in _attn_fwd?
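A minimal sketch of the usual pointer-based pattern when the context length is not a multiple of the block size: launch `triton.cdiv(N_CTX, BLOCK)` programs and mask the tail block. The PR's kernel uses TMA descriptor loads, so this only illustrates the general idea, not its actual code:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows(x_ptr, out_ptr, n_ctx, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_ctx  # guards the final partial block when n_ctx % BLOCK != 0
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(out_ptr + offs, x, mask=mask)


def copy_rows(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n_ctx = x.numel()
    # Ceil-division so the last (partial) block still gets a program.
    grid = lambda meta: (triton.cdiv(n_ctx, meta["BLOCK"]),)
    _copy_rows[grid](x, out, n_ctx, BLOCK=128)
    return out
```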
```python
offset_y = off_z * (N_CTX * H) + off_h * N_CTX
qo_offset_y = offset_y + start_m * BLOCK_FIXED
# initialize offsets
offs_m = start_m * BLOCK_FIXED + tl.arange(0, BLOCK_FIXED)
```
Do we need N_CTX % PRE_BLOCK == 0 here as well? Or a mask to avoid writing at invalid locations?
```python
# This frees up threads and registers to do other computations
# TMA requires global memory allocations, so we set the alloc_fn here
def alloc_fn(size: int, align: int, _):
    return torch.empty(size, dtype=torch.int8, device="cuda")
```
Could we allocate on the current device? Or is "cuda" always safe / will we always get the correct device?
```python
def alloc_fn(size: int, align: int, _):
    return torch.empty(size, dtype=torch.int8, device="cuda")


triton.set_allocator(alloc_fn)
```
Is this set in every call? Should we set this once somewhere?
This is also interesting, but not really relevant I guess: pytorch/pytorch#155584
Good catch. Now it's only called once, when the file is imported. Currently there isn't a public way to check if the allocator has been set (here). I left a TODO to add a check once there is a way.
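A sketch of what a device-aware, import-time registration could look like; whether the third `alloc_fn` argument is the stream depends on the Triton version, so treat the details as assumptions:

```python
import torch
import triton


def _alloc_fn(size: int, align: int, _stream):
    # Allocate TMA scratch on the caller's current CUDA device rather than the
    # hard-coded "cuda" default, so multi-GPU ranks land on their own device.
    device = torch.device("cuda", torch.cuda.current_device())
    return torch.empty(size, dtype=torch.int8, device=device)


# Register once at module import; there is currently no public API to query
# whether an allocator is already set (hence the TODO mentioned above).
triton.set_allocator(_alloc_fn)
```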
```python
# -- update output accumulator --
acc = acc * alpha[:, None]
# prepare p and v for the dot
v = desc_v.load([iter_offset, 0])
```
Do we need to be careful not to load something that goes out of bounds?
```python
curr_iter = tl.multiple_of(curr_iter, BLOCK_ITER)  # Tells compiler curr_iter is a multiple of BLOCK_ITER


# -- compute qk ----
k = desc_k.load([iter_offset, 0]).T
```
Do we need to be careful not to load something that goes out of bounds?
```python
qk_scale = sm_scale
qk_scale *= 1.44269504  # 1/ln(2): hack to make computing the exponent faster; by merging in 1/ln(2) now, the cheaper exp2() fn can be called later instead of exp()
# load q: it will stay in SRAM throughout
q = desc_q.load([fixed_offset, 0])
```
Do we need a mask so we don't load anything beyond N_CTX?
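As a side note on the `qk_scale *= 1.44269504` line in the hunk above: exp(x) = exp2(x · log2(e)) and log2(e) = 1/ln(2) ≈ 1.4426950408889634, so folding the constant into the softmax scale once lets the kernel call the cheaper `exp2()` instead of `exp()` later. A quick numeric check (values are arbitrary examples):

```python
import math

LOG2_E = 1.4426950408889634  # log2(e) == 1 / ln(2)
x, sm_scale = 0.73, 0.125
# exp(sm_scale * x) computed via exp2 with the constant folded into the scale
assert math.isclose(math.exp(sm_scale * x), 2.0 ** ((sm_scale * LOG2_E) * x))
print("exp via exp2 matches")
```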
```diff
@@ -1,4 +1,4 @@
-# (C) Copyright 2024 Anemoi contributors.
+# (C) Copyright 2026 Anemoi contributors.
```
Suggested change:
```diff
-# (C) Copyright 2026 Anemoi contributors.
+# (C) Copyright 2024- Anemoi contributors.
```
```python
LOGGER = logging.getLogger(__name__)


# Change attention implementation during inference runtime
ATTENTION_BACKEND = os.environ.get("ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND", "")
```
At some point I'd like to consolidate these into utils.env
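A rough sketch of the kind of consolidation meant here; `utils.env` and the helper name are hypothetical, not existing code:

```python
import os


# Hypothetical helper, e.g. in a future anemoi utils.env module.
def get_env(name: str, default: str = "") -> str:
    """Single entry point for reading Anemoi-related environment variables."""
    return os.environ.get(name, default)


ATTENTION_BACKEND = get_env("ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND")
```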
```python
}


# Check if 'ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND' env var has been set
if ATTENTION_BACKEND != "":
```
Suggested change:
```diff
-if ATTENTION_BACKEND != "":
+if ATTENTION_BACKEND:
```
```python
value = self.lin_v(x)


# Check at runtime if the Attention backend env var has been set, and update attention backend accordingly
if ATTENTION_BACKEND != "":
```
Suggested change:
```diff
-if ATTENTION_BACKEND != "":
+if ATTENTION_BACKEND:
```
| "Dropout probability used for multi-head self attention, default 0.0" | ||
| attention_implementation: str = Field(example="flash_attention") | ||
| "Attention implementation to use. Default to 'flash_attention'." | ||
| attention_implementation: str = Field(example="triton_attention") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As this is a required str, it doesn't really default to triton_attention
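To illustrate the point (a reduced sketch, not the PR's actual schema): with pydantic, `example=` only documents a sample value and the field remains required; an actual default would need `default=` or a plain assignment:

```python
from pydantic import BaseModel, Field


class ProcessorSchema(BaseModel):  # hypothetical, reduced schema
    # `example` is documentation only -- the field is still required:
    attention_implementation: str = Field(example="triton_attention")
    # An actual default would look like:
    # attention_implementation: str = Field(default="triton_attention")
```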
```diff
 window_size: 512
 dropout_p: 0.0 # GraphTransformer
-attention_implementation: flash_attention # flash_attention, scaled_dot_product_attention
+attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
```
Suggested change:
```diff
-attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention
```
```diff
 window_size: 512
 dropout_p: 0.0
-attention_implementation: flash_attention # Possible values: scaled_dot_product_attention, flash_attention
+attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
```
Suggested change:
```diff
-attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention
```
```diff
 window_size: 512
 dropout_p: 0.0 # GraphTransformer
-attention_implementation: flash_attention # flash_attention, scaled_dot_product_attention
+attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
```
Suggested change:
```diff
-attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention
```
```diff
 window_size: -1
 dropout_p: 0.0
-attention_implementation: flash_attention # Possible values: scaled_dot_product_attention, flash_attention
+attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
```
Suggested change:
```diff
-attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention
```
```diff
 window_size: -1
 dropout_p: 0.0
-attention_implementation: flash_attention # Possible values: scaled_dot_product_attention, flash_attention
+attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
```
Suggested change:
```diff
-attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention
```


Description
Triton backend for the transformer processor. As fast as Flash Attention 2, without all the hassle of installing it.
I based it on the official (MIT-licensed) Triton fused-attention demo and added support for sliding windows. I also changed the backward-pass structure to make it simpler and easier to extend with other attention modifications.
Performance tests and loss comparisons against longer runs are shown in the comments below.
There is a pytest suite which tests numerous different configurations against a reference implementation.
This PR changes the default attention implementation when using the transformer processor to the 'triton' backend. Since the triton backend matches the performance of Flash Attention v2 and does not have to be installed separately, this will allow users to train transformer models out of the box.
I also added two env vars that allow users to select the attention implementation at inference time. This is a quality-of-life feature a few users have suggested.
In the future it would be better to replace the use of env vars with passing this information through the anemoi-inference config.
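A usage sketch for the inference-time override; only the env var name that appears in the diff above is shown, and the exact value string accepted here is an assumption:

```python
import os

# Select the attention backend at inference time without editing the checkpoint config.
os.environ["ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND"] = "triton_attention"
```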
📚 Documentation preview 📚: https://anemoi-training--716.org.readthedocs.build/en/716/
📚 Documentation preview 📚: https://anemoi-graphs--716.org.readthedocs.build/en/716/
📚 Documentation preview 📚: https://anemoi-models--716.org.readthedocs.build/en/716/