@cathalobrien (Contributor)
Description

This PR adds an option to the Triton GT kernel to normalise Q and K.

Currently, when QK normalisation is enabled in Anemoi (e.g. in the ensemble processor), a LayerNorm is computed over Q and K before calling the GT. This is inefficient because it forces us to load all of Q and K twice (once for the normalisation and once for computing the GT conv), and it increases memory usage because more intermediate tensors must be saved for the backward pass.

This PR fuses QK normalisation into the GT conv by applying the normalisation right after Q and K are loaded in the GT kernel. Since loading the elements from memory is the expensive part, and that happens anyway, the runtime cost of enabling QK norm is quite low. Benchmarking results are shown below: it is faster and uses less memory than applying a compiled RMS norm followed by the Triton GT kernel.

The normalisation method implemented is RMSNorm, which differs from the LayerNorm used currently but is still widely used across ML. RMSNorm was chosen because LayerNorm isn't suited to the tiled memory access pattern of the GT conv (it requires a global mean and variance).

The RMSNorm Triton kernels are split into a separate file, for clarity and future reuse. The GT pytests have been extended to test the RMSNorm option.
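For reference, the normalisation applied inside the kernel is equivalent to the following pure-PyTorch RMSNorm over the head dimension (a minimal sketch for illustration; the actual fused Triton kernel applies this per tile as Q/K are loaded, and tensor shapes here are just examples):

```python
import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """RMSNorm over the last (head) dimension: x / sqrt(mean(x^2) + eps).

    Unlike LayerNorm, there is no mean subtraction, so each row can be
    normalised independently inside a tile without global statistics.
    """
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms

# q shaped (num_edges, num_heads, head_dim); each head row is normalised
q = torch.randn(8, 4, 16)
q_normed = rms_norm(q)
```

Because each row only needs its own mean-of-squares, this maps cleanly onto the tiled access pattern, which is exactly why RMSNorm was chosen over LayerNorm.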



=== Benchmarking graph: era_to_h fwd+bwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 39.23 ms / iter
        peak memory usage: 12709.78 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 9.14 ms / iter
        peak memory usage: 9886.27 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 14.84 ms / iter
        peak memory usage: 11024.06 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 12.12 ms / iter
        peak memory usage: 9886.24 MB



=== Benchmarking graph: h_to_h fwd+bwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 10.83 ms / iter
        peak memory usage: 5916.90 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 1.43 ms / iter
        peak memory usage: 3092.68 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 5.19 ms / iter
        peak memory usage: 3252.27 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 2.19 ms / iter
        peak memory usage: 3092.65 MB



=== Benchmarking graph: h_to_era fwd+bwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 52.97 ms / iter
        peak memory usage: 19149.75 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 13.76 ms / iter
        peak memory usage: 14724.34 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 19.84 ms / iter
        peak memory usage: 15863.98 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 14.39 ms / iter
        peak memory usage: 14724.31 MB



=== Benchmarking graph: dop_enc_large fwd+bwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 44.08 ms / iter
        peak memory usage: 16828.71 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 13.04 ms / iter
        peak memory usage: 13126.87 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 19.45 ms / iter
        peak memory usage: 14955.59 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 12.06 ms / iter
        peak memory usage: 13126.84 MB

base case Triton GraphTransformer with no normalisation:
apply
        elapsed time: 0.34 ms / iter
        peak memory usage: 2213.54 MB

layer_norm & Triton GraphTransformer:
apply
        elapsed time: 12.10 ms / iter
        peak memory usage: 12629.29 MB

fused qk_norm Triton GraphTransformer:
apply
        elapsed time: 0.37 ms / iter
        peak memory usage: 2213.54 MB
github-project-automation bot moved this to To be triaged in Anemoi-dev Dec 19, 2025
@cathalobrien cathalobrien added the ATS Approval Not Needed No approval needed by ATS label Dec 19, 2025
@cathalobrien cathalobrien self-assigned this Dec 19, 2025
@cathalobrien cathalobrien requested a review from japols December 19, 2025 12:15
@cathalobrien (Contributor Author)
Forward pass only (the speedup is better here, since the backward pass has to recompute the normalisation scaling factor many times):
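The recompute-in-backward trade-off mentioned above can be sketched with a custom autograd function that saves only the input and re-derives the 1/rms scaling factor during backward, trading a little extra compute for activation memory (a hypothetical minimal sketch, not the Triton implementation; the class name is illustrative):

```python
import torch

class RecomputeRMSNorm(torch.autograd.Function):
    """RMSNorm that deliberately does NOT save the 1/rms factor for
    backward; it is recomputed from the saved input instead, which
    lowers peak memory at the cost of extra FLOPs in the backward pass."""

    EPS = 1e-6

    @staticmethod
    def forward(ctx, x):
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + RecomputeRMSNorm.EPS)
        ctx.save_for_backward(x)  # inv_rms intentionally not saved
        return x * inv_rms

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        # recompute the scaling factor instead of loading it from memory
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + RecomputeRMSNorm.EPS)
        y = x * inv_rms
        # analytic RMSNorm gradient: (g - y * mean(g*y)) / rms
        return (grad_out - y * (grad_out * y).mean(-1, keepdim=True)) * inv_rms
```

This is why the fwd-only numbers show a larger relative speedup than fwd+bwd: the backward pass pays for the recomputation.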

=== Benchmarking graph: era_to_h fwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 15.72 ms / iter
        peak memory usage: 10895.90 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 1.87 ms / iter
        peak memory usage: 10006.98 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 7.92 ms / iter
        peak memory usage: 10869.22 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 1.89 ms / iter
        peak memory usage: 6098.98 MB



=== Benchmarking graph: h_to_h fwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 4.15 ms / iter
        peak memory usage: 7895.01 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 0.34 ms / iter
        peak memory usage: 3130.54 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 1.09 ms / iter
        peak memory usage: 3173.83 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 0.38 ms / iter
        peak memory usage: 2213.50 MB



=== Benchmarking graph: h_to_era fwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 19.14 ms / iter
        peak memory usage: 19823.18 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 1.92 ms / iter
        peak memory usage: 14901.28 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 5.33 ms / iter
        peak memory usage: 14725.23 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 2.09 ms / iter
        peak memory usage: 10345.85 MB



=== Benchmarking graph: dop_enc_large fwd ===

rms_norm then pyg (non-compiled):

Compiled rms_norm x pyg:
OptimizedModule
        elapsed time: 10.08 ms / iter
        peak memory usage: 17544.41 MB


 standalone Triton GraphTransformer:
apply
        elapsed time: 1.59 ms / iter
        peak memory usage: 13281.20 MB


 compiled rms norm then standalone Triton GraphTransformer:
apply
        elapsed time: 5.83 ms / iter
        peak memory usage: 14805.99 MB


 fused rms_norm x Triton GraphTransformer:
apply
        elapsed time: 1.64 ms / iter
        peak memory usage: 7801.31 MB

@HCookie HCookie moved this from To be triaged to Now In Progress in Anemoi-dev Jan 6, 2026
…on gt

TODO: clean up by wrapping the autograd function in an nn.Module which creates the weights and manages passing them to the bwd pass
@cathalobrien (Contributor Author)
@japols noticed I forgot to track the weights for q_norm and k_norm (see 'elementwise_affine').
I have implemented that now; the performance and memory impact is very small since the weights are only of size 'head_dim'. It makes the code a bit uglier as there are more pointers to manage.

I also added standalone pytests for the RMS norm, separate from the Triton GT.

Runtime and memory for fwd+bwd are shown below (column 2 is the standalone Triton GT, column 3 is compiled RMS norm followed by the Triton GT, and column 4 is the Triton GT with fused RMS norm).

[charts: era_to_h fwd+bwd, h_to_h fwd+bwd, h_to_era fwd+bwd, dop_enc_large fwd+bwd]

@cathalobrien (Contributor Author)
There is a speedup and a memory reduction in the ensemble benchmark test, where qk_norm is used in the processor:

https://github.com/ecmwf/anemoi-core/actions/runs/20821574437/job/59811318231#step:3:3386

[screenshot: ensemble benchmark results, 2026-01-08]
