feat(models): Add fused qk_norm to triton GT kernel #770
Conversation
Benchmark results (forward pass only, measured on the `apply` call; the speedup is larger here than for fwd+bwd, since the bwd pass has to recompute the normalisation scaling factor many times):

| Configuration | Elapsed time | Peak memory usage |
| --- | --- | --- |
| Base case: Triton GraphTransformer, no normalisation | 0.34 ms / iter | 2213.54 MB |
| layer_norm + Triton GraphTransformer | 12.10 ms / iter | 12629.29 MB |
| Fused qk_norm Triton GraphTransformer | 0.37 ms / iter | 2213.54 MB |
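For context, numbers like those above can be reproduced with a standard CUDA-event harness along these lines (a minimal sketch; the actual benchmark script, inputs and iteration counts are not shown here):

```python
import torch

def benchmark(fn, *args, iters=100, warmup=10):
    """Rough ms/iter and peak-memory harness (illustrative, not the exact script used above)."""
    for _ in range(warmup):
        fn(*args)
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn(*args)
    end.record()
    torch.cuda.synchronize()
    ms_per_iter = start.elapsed_time(end) / iters
    peak_mb = torch.cuda.max_memory_allocated() / 1024**2
    return ms_per_iter, peak_mb
```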
…on gt TODO: clean up by wrapping the autograd function in an nn.Module which creates the norm weights and manages passing them to the bwd pass
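A rough sketch of what that TODO could look like (class and function names here are hypothetical, not the actual implementation; the backward is left unimplemented since the real kernel recomputes the scaling factors itself):

```python
import torch
from torch import nn

class _FusedGTConvFn(torch.autograd.Function):
    """Stand-in for the real fused Triton autograd function (illustrative only)."""

    @staticmethod
    def forward(ctx, q, k, q_weight, k_weight, eps):
        # The real kernel normalises q/k and runs the GT conv in one pass;
        # this stub only shows how the norm weights are threaded through to bwd.
        ctx.save_for_backward(q, k, q_weight, k_weight)
        ctx.eps = eps
        q_n = q * torch.rsqrt(q.pow(2).mean(-1, keepdim=True) + eps) * q_weight
        k_n = k * torch.rsqrt(k.pow(2).mean(-1, keepdim=True) + eps) * k_weight
        return q_n, k_n

    @staticmethod
    def backward(ctx, grad_qn, grad_kn):
        q, k, q_weight, k_weight = ctx.saved_tensors
        # The real backward recomputes the normalisation scaling factors from
        # q, k and the saved weights; omitted in this sketch.
        raise NotImplementedError("illustrative sketch only")

class FusedQKNormGTConv(nn.Module):
    """Hypothetical wrapper that owns the q/k RMSNorm scale weights and passes
    them into the autograd function, as the TODO above describes."""

    def __init__(self, head_dim: int, eps: float = 1e-6):
        super().__init__()
        # elementwise_affine-style scale parameters for q and k
        self.q_weight = nn.Parameter(torch.ones(head_dim))
        self.k_weight = nn.Parameter(torch.ones(head_dim))
        self.eps = eps

    def forward(self, q, k):
        return _FusedGTConvFn.apply(q, k, self.q_weight, self.k_weight, self.eps)
```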
@japols noticed I forgot to track weights for q_norm and k_norm (see `elementwise_affine` here). I also added standalone pytests just for the RMS norm, separate from the Triton GT tests. Runtime and memory for fwd+bwd are shown below (column 2 is the standalone Triton GT, column 3 is a compiled RMS norm followed by the Triton GT, and column 4 is the Triton GT with fused RMS norm).
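The standalone RMSNorm tests could look roughly like this (a sketch only; `triton_rms_norm` is a placeholder name for the kernel under test, and shapes and tolerances are illustrative):

```python
import pytest
import torch

@pytest.mark.parametrize("shape", [(64, 128), (1024, 256)])
def test_rms_norm_matches_reference(shape):
    # triton_rms_norm stands in for the standalone Triton RMSNorm being tested
    x = torch.randn(*shape, device="cuda", dtype=torch.float32)
    weight = torch.randn(shape[-1], device="cuda", dtype=torch.float32)
    eps = 1e-6

    out = triton_rms_norm(x, weight, eps=eps)

    # plain-PyTorch RMSNorm reference with an elementwise_affine-style scale
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    expected = x / rms * weight

    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
```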
There is a speedup and memory reduction in the ensemble benchmark test, where qk_norm is used in the processor: https://github.com/ecmwf/anemoi-core/actions/runs/20821574437/job/59811318231#step:3:3386





Description
This PR adds an option to the Triton GT kernel to normalise Q and K.
Currently, when QK normalisation is done in Anemoi (e.g. ensemble processor), a layernorm is computed over q and k before calling the GT. This is inefficient as it forces us to load all of Q and K twice (once for normalisation and once for computing GT.conv), and increases the memory usage as more intermediate inputs must be saved for the backward pass.
This PR fuses QK normalisation into the GT conv by applying normalisation right after Q and K are loaded in the GT kernel. Since loading the elements from memory is the expensive part, and that is done anyway, the runtime cost of enabling QK norm is quite low. Benchmarking results are shown below: it is faster and uses less memory than applying a compiled RMS norm followed by the Triton GT kernel.
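Conceptually, the fusion point looks like the sketch below: each q/k row is normalised in registers right after it is loaded, rather than in a separate pass over global memory. This is an illustrative standalone kernel, not the actual GT kernel; pointer layout and names are simplified, and the surrounding attention logic is omitted.

```python
import triton
import triton.language as tl

@triton.jit
def _qk_rms_norm_inline(q_ptr, k_ptr, qw_ptr, kw_ptr, oq_ptr, ok_ptr,
                        HEAD_DIM: tl.constexpr, EPS: tl.constexpr):
    # One program handles one q/k row; in the real fused kernel the
    # normalised tiles feed straight into the attention-score computation.
    row = tl.program_id(0)
    d = tl.arange(0, HEAD_DIM)
    q = tl.load(q_ptr + row * HEAD_DIM + d).to(tl.float32)
    k = tl.load(k_ptr + row * HEAD_DIM + d).to(tl.float32)
    # RMSNorm reuses the tile already in registers -- no second pass over Q/K
    q = q / tl.sqrt(tl.sum(q * q, axis=0) / HEAD_DIM + EPS) * tl.load(qw_ptr + d)
    k = k / tl.sqrt(tl.sum(k * k, axis=0) / HEAD_DIM + EPS) * tl.load(kw_ptr + d)
    tl.store(oq_ptr + row * HEAD_DIM + d, q)
    tl.store(ok_ptr + row * HEAD_DIM + d, k)
```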
The normalisation method implemented is RMSNorm, which is different from the LayerNorm used currently but still widely used across ML. RMSNorm was chosen because LayerNorm isn't suited to the tiled memory access pattern of the GT conv (it requires maintaining a global mean and variance).
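For reference, the two normalisations over a row $x \in \mathbb{R}^d$ (with the usual learnable affine terms; the eps used in the kernel is not stated here):

$$
\mathrm{RMSNorm}(x)_i = \frac{x_i}{\sqrt{\tfrac{1}{d}\sum_{j=1}^{d} x_j^2 + \epsilon}}\, w_i,
\qquad
\mathrm{LayerNorm}(x)_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\, w_i + b_i,
\quad \mu = \tfrac{1}{d}\sum_j x_j,\ \ \sigma^2 = \tfrac{1}{d}\sum_j (x_j - \mu)^2.
$$

The practical difference for the kernel is that RMSNorm only needs the sum of squares of the loaded row, with no mean subtraction.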
The RMSNorm Triton kernels are split into a separate file, for clarity and future reuse. The GT pytests have been extended to test the RMSNorm option.