
Conversation

@cathalobrien (Contributor) commented Dec 3, 2025

Description

Adds a Triton backend for the transformer processor. It is as fast as Flash Attention 2, without all the hassle of installing it.

It is based on the official (MIT-licensed) Triton fused attention demo, with added support for sliding-window attention. I also restructured the backward pass to make it simpler and easier to extend with other attention modifications.

Performance tests and loss comparisons against longer runs are shown in the comments below.

There is a pytest suite that tests numerous configurations against a reference implementation:

========== 144 passed, 240 skipped, 1 warning in 19.01s ==========

This PR changes the default attention implementation of the transformer processor to the 'triton' backend. Since the Triton backend matches the performance of Flash Attention 2 and requires no extra installation, users can train transformer models out of the box.

I also added two environment variables that let users select the attention implementation at inference time, a quality-of-life feature several users have requested.

2026-01-13 15:33:22 INFO 'ANEMOI_INFERENCE_GRAPHTRANSFORMER_ATTENTION_BACKEND' environment variable has been set. Overwriting attention backend from triton to pyg

2026-01-13 15:27:07 INFO 'ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND' environment variable has been set. Overwriting attention backend from triton to flash_attention

In the future it would be better to replace these environment variables with settings passed through the anemoi-inference config.
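For illustration, a minimal sketch of how these overrides could be set from a user script before the inference model is loaded (the backend values shown are the ones from the log messages above):

import os

# Override the attention backends read by anemoi-models at inference time.
# "flash_attention" and "pyg" are the values seen in the log messages above.
os.environ["ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND"] = "flash_attention"
os.environ["ANEMOI_INFERENCE_GRAPHTRANSFORMER_ATTENTION_BACKEND"] = "pyg"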


📚 Documentation preview 📚: https://anemoi-training--716.org.readthedocs.build/en/716/


📚 Documentation preview 📚: https://anemoi-graphs--716.org.readthedocs.build/en/716/


📚 Documentation preview 📚: https://anemoi-models--716.org.readthedocs.build/en/716/

fused-attention-batch1-head8-d64-fwd-window=0:
     N_CTX  Triton [FP16]       Flash-2
0   1024.0      53.951905     30.154757
1   2048.0     251.862878    225.601753
2   4096.0    1003.585900    883.967242
3   8192.0    3087.151752   2919.118612
4  16384.0   14424.333821  12164.023961
fused-attention-batch1-head8-d64-fwd-window=256:
     N_CTX  Triton [FP16]      Flash-2
0   1024.0      66.118373    49.047213
1   2048.0     212.385098   224.781601
2   4096.0     627.400725   754.676149
3   8192.0    2403.518850  2072.301033
4  16384.0    5407.098626  4389.974473
fused-attention-batch1-head8-d64-bwd-window=0:
     N_CTX  Triton [FP16]      Flash-2
0   1024.0      29.038600    28.145139
1   2048.0     103.598599   107.078404
2   4096.0     221.267879   466.277452
3   8192.0     238.154238  1926.255278
4  16384.0     244.473300  7453.972862
fused-attention-batch1-head8-d64-bwd-window=256:
     N_CTX  Triton [FP16]      Flash-2
0   1024.0      31.743517    30.864637
1   2048.0      90.926897   119.329910
2   4096.0     219.786079   486.217597
3   8192.0     236.157174  1804.149605
4  16384.0     243.949151  4204.442805
fused-attention-batch1-head8-d128-fwd-window=0:
     N_CTX  Triton [FP16]       Flash-2
0   1024.0     113.015749     97.006563
1   2048.0     355.368512    378.363736
2   4096.0    1539.535832   1366.327294
3   8192.0    7007.632047   6028.029964
4  16384.0   19762.887419  17225.973294
fused-attention-batch1-head8-d128-fwd-window=256:
     N_CTX  Triton [FP16]      Flash-2
0   1024.0     103.508837    89.663668
1   2048.0     525.090419   464.968188
2   4096.0    1235.160998  1332.826567
3   8192.0    3203.617410  2936.990470
4  16384.0    6880.275897  6123.385195
fused-attention-batch1-head8-d128-bwd-window=0:
     N_CTX  Triton [FP16]      Flash-2
0   1024.0      31.435884    62.969277
1   2048.0      58.816000   240.568425
2   4096.0      81.556234   987.361971
3   8192.0      86.844913  3692.221169
4  16384.0      88.609923  8978.312907
fused-attention-batch1-head8-d128-bwd-window=256:
     N_CTX  Triton [FP16]      Flash-2
0   1024.0      31.168216    62.057629
1   2048.0      57.411547   240.228307
2   4096.0      73.313512   920.541072
3   8192.0      75.291760  2300.896123
4  16384.0      75.907589  4835.565817
@github-project-automation github-project-automation bot moved this to To be triaged in Anemoi-dev Dec 3, 2025
@cathalobrien cathalobrien added the ATS Approval Needed Approval needed by ATS label Dec 3, 2025
@github-actions github-actions bot added the models label Dec 3, 2025
@cathalobrien cathalobrien changed the title feat(training)/triton-attn feat(models)/triton-attn Dec 3, 2025
@cathalobrien cathalobrien changed the title feat(models)/triton-attn feat(models): Triton Attention backend Dec 3, 2025
@HCookie HCookie moved this from To be triaged to Reviewers needed in Anemoi-dev Dec 4, 2025
@cathalobrien (Contributor, Author) commented:

Something strange is afoot with the loss for longer runs.

[Screenshot 2025-12-05 at 14:05:47]

@cathalobrien (Contributor, Author) commented:

I compared an aifs-single setup on 4 GH200s for 4000 iterations here.

Loss looks good, and the triton backend completed faster.

[Screenshot 2026-01-13 at 16:12:45]

…ad dims. Complex because there are more block sizes and dependencies between them due to sharing a grid. Won't run, but pytest, which doesn't autotune, passes.
@japols (Member) left a comment:

Nice work!

Some comments on dtypes and a few minor changes.

@github-project-automation github-project-automation bot moved this from Reviewers needed to Under Review in Anemoi-dev Jan 14, 2026
cathalobrien and others added 19 commits January 14, 2026 13:56
@ssmmnn11 (Member) left a comment:

nice!

pre_hook=_host_descriptor_pre_hook,
)
for BM in [32, 64, 128]
for BN in [32, 64, 128]
Member:

no 16?

Contributor (Author):

Good idea. It passes the pytests and gives a speedup from 8.55 ms/iter to 7.06 ms/iter at 2048 channels... maybe I should try 8 :D
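For context, a sketch of what widening the autotune sweep might look like (parameter names and num_stages/num_warps values are illustrative; the actual config list in this PR also passes pre_hook=_host_descriptor_pre_hook):

import triton

# Illustrative autotune config sweep including the smaller 16x16 tiles discussed above.
configs = [
    triton.Config({"BLOCK_M": bm, "BLOCK_N": bn}, num_stages=stages, num_warps=warps)
    for bm in (16, 32, 64, 128)
    for bn in (16, 32, 64, 128)
    for stages in (3, 4)
    for warps in (4, 8)
]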

# Meaning there is at least (BATCH_SIZE * NUM_HEADS) SMs
# Depending on BLOCK_FIXED, the context window might also be split across SMs
# BLOCK_FIXED is a hyperparameter which triton sets at runtime by running small performance tests
def grid(META):
Member:

Is N_CTX always divisible by BLOCK_FIXED? If not, do we need a mask in _attn_fwd?

offset_y = off_z * (N_CTX * H) + off_h * N_CTX
qo_offset_y = offset_y + start_m * BLOCK_FIXED
# initialize offsets
offs_m = start_m * BLOCK_FIXED + tl.arange(0, BLOCK_FIXED)
Member:

Do we need N_CTX % PRE_BLOCK == 0 here as well? Or a mask to avoid writing to invalid locations?
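For reference, the standard Triton boundary-mask pattern the two questions above are pointing at, shown on a self-contained toy copy kernel rather than the attention kernel itself (which uses TMA descriptor loads, where tails are handled differently or the length must be a multiple of the block size):

import torch
import triton
import triton.language as tl


@triton.jit
def _masked_copy(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    # Each program handles one BLOCK-sized tile; the mask guards the tail
    # tile when n is not a multiple of BLOCK, for both the load and the store.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(y_ptr + offs, x, mask=mask)


def masked_copy(x: torch.Tensor) -> torch.Tensor:
    y = torch.empty_like(x)
    BLOCK = 128
    grid = (triton.cdiv(x.numel(), BLOCK),)
    _masked_copy[grid](x, y, x.numel(), BLOCK=BLOCK)
    return y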

# This frees up threads and registers to do other computations
# TMA requires global memory allocations, so we set the alloc_fn here
def alloc_fn(size: int, align: int, _):
return torch.empty(size, dtype=torch.int8, device="cuda")
Member:

Could we allocate on the current device? Or is "cuda" always safe, i.e. will we always get the correct device?

def alloc_fn(size: int, align: int, _):
return torch.empty(size, dtype=torch.int8, device="cuda")

triton.set_allocator(alloc_fn)
@ssmmnn11 (Member) commented Jan 15, 2026:

Is this set in every call? Should we set this once somewhere?

This is also interesting, but not really relevant I guess: pytorch/pytorch#155584

Contributor (Author):

Good catch. Now it's only called once, when the file is imported. Currently there isn't a public way to check whether the allocator has already been set (here). I left a TODO to add a check once there is one.
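For illustration, a sketch of the module-level setup being described; allocating on the currently active device (rather than the bare "cuda" alias) is an assumption addressing the earlier device question, not necessarily what the PR does:

import torch
import triton


def _alloc_fn(size: int, align: int, _):
    # Workspace buffer for the TMA descriptors; allocate on the currently
    # active CUDA device instead of the default "cuda" alias.
    return torch.empty(size, dtype=torch.int8, device=f"cuda:{torch.cuda.current_device()}")


# Set once at import time.
# TODO: guard this once Triton exposes a public way to query whether an
# allocator has already been installed.
triton.set_allocator(_alloc_fn)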

# -- update output accumulator --
acc = acc * alpha[:, None]
# prepare p and v for the dot
v = desc_v.load([iter_offset, 0])
Member:

Do we need to be careful not to load something that goes out of bounds?

curr_iter = tl.multiple_of(curr_iter, BLOCK_ITER) # Tells compiler curr_iter is a multiple of BLOCK_ITER

# -- compute qk ----
k = desc_k.load([iter_offset, 0]).T
Member:

Do we need to be careful not to load something that goes out of bounds?

qk_scale = sm_scale
qk_scale *= 1.44269504  # log2(e) = 1/ln(2): folding it into the scale lets the cheaper exp2() be called later instead of exp()
# load q: it will stay in SRAM throughout
q = desc_q.load([fixed_offset, 0])
Member:

Do we need a mask so we don't load anything beyond N_CTX?
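As an aside on the 1.44269504 constant in the hunk above: it is log2(e) = 1/ln(2), folded into the softmax scale so that the cheaper exp2() can replace exp() later in the kernel:

e^{s \cdot qk} = 2^{(s \cdot \log_2 e) \cdot qk}, \qquad \log_2 e = \tfrac{1}{\ln 2} \approx 1.44269504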

@HCookie HCookie self-requested a review January 21, 2026 16:32
@@ -1,4 +1,4 @@
# (C) Copyright 2024 Anemoi contributors.
# (C) Copyright 2026 Anemoi contributors.
Member:

Suggested change:
- # (C) Copyright 2026 Anemoi contributors.
+ # (C) Copyright 2024- Anemoi contributors.

LOGGER = logging.getLogger(__name__)

# Change attention implementation during inference runtime
ATTENTION_BACKEND = os.environ.get("ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND", "")
Member:

At some point I'd like to consolidate these into utils.env.
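A possible shape for such a helper (names are hypothetical; anemoi's actual utils.env may look different):

import logging
import os

LOGGER = logging.getLogger(__name__)


def env_override(name: str, default: str = "") -> str:
    # Read an ANEMOI_* override from the environment; an empty string means unset.
    value = os.environ.get(name, default)
    if value:
        LOGGER.info("'%s' environment variable has been set.", name)
    return value


# Usage mirroring the module-level lookup in this PR:
ATTENTION_BACKEND = env_override("ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND")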

}

# Check if 'ANEMOI_INFERENCE_TRANSFORMER_ATTENTION_BACKEND' env var has been set
if ATTENTION_BACKEND != "":
Member:

Suggested change:
- if ATTENTION_BACKEND != "":
+ if ATTENTION_BACKEND:

value = self.lin_v(x)

# Check at runtime if the Attention backend env var has been set, and update attention backend accordingly
if ATTENTION_BACKEND != "":
Member:

Suggested change:
- if ATTENTION_BACKEND != "":
+ if ATTENTION_BACKEND:

"Dropout probability used for multi-head self attention, default 0.0"
attention_implementation: str = Field(example="flash_attention")
"Attention implementation to use. Default to 'flash_attention'."
attention_implementation: str = Field(example="triton_attention")
Member:

As this is a required str, it doesn't really default to triton_attention.
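For illustration, a hedged sketch of the distinction being pointed out, using a hypothetical schema class: Field(example=...) only annotates the schema, whereas Field(default=...) would make the field genuinely optional with that value:

from pydantic import BaseModel, Field


class TransformerProcessorSchema(BaseModel):  # hypothetical class name
    # With only example=..., the field remains required; adding default=...
    # would make it actually default to "triton_attention".
    attention_implementation: str = Field(default="triton_attention", example="triton_attention")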

window_size: 512
dropout_p: 0.0 # GraphTransformer
attention_implementation: flash_attention # flash_attention, scaled_dot_product_attention
attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
Member:

Suggested change:
- attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+ attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention

window_size: 512
dropout_p: 0.0
attention_implementation: flash_attention # Possible values: scaled_dot_product_attention, flash_attention
attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
Member:

Suggested change:
- attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+ attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention

window_size: 512
dropout_p: 0.0 # GraphTransformer
attention_implementation: flash_attention # flash_attention, scaled_dot_product_attention
attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
Member:

Suggested change:
- attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+ attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention

window_size: -1
dropout_p: 0.0
attention_implementation: flash_attention # Possible values: scaled_dot_product_attention, flash_attention
attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
Member:

Suggested change:
- attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+ attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention

window_size: -1
dropout_p: 0.0
attention_implementation: flash_attention # Possible values: scaled_dot_product_attention, flash_attention
attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
Member:

Suggested change:
- attention_implementation: triton # Possible values: flash_attention, triton, scaled_dot_product_attention
+ attention_implementation: triton_attention # Possible values: flash_attention, triton_attention, scaled_dot_product_attention
