CUDNN jax.nn.dot_product_attention fails on CUDA 13 #32385

@young-geng

Description

jax.nn.dot_product_attention fails with an XlaRuntimeError on the cudnn backend in JAX 0.7.2 with CUDA 13. I've reported this bug in the JAX repo (this issue), and a JAX dev told me I should report it here.

Here's a quick repro:

import jax
import jax.numpy as jnp

# Shared q/k/v tensor with layout (batch, seq_len, num_heads, head_dim).
qkv = jax.random.normal(jax.random.PRNGKey(42), (2, 8192, 4, 128), dtype=jnp.bfloat16)
flash_attention_output = jax.nn.dot_product_attention(
    query=qkv, key=qkv, value=qkv,
    implementation='cudnn',  # request the cuDNN fused attention kernel
    is_causal=True,
)

This piece of code produces the following error:

Traceback (most recent call last):
  File "/nfs/nest/program/scalax/cuda13/repro.py", line 5, in <module>
    flash_attention_output = jax.nn.dot_product_attention(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/nn/functions.py", line 1238, in dot_product_attention
    out = cudnn_dot_product_attention(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 2060, in dot_product_attention
    output = _dot_product_attention(
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 1219, in _dot_product_attention
    output = _dot_product_attention_fwd(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 443, in _dot_product_attention_fwd
    outputs = _dot_product_attention_fwd_p_wrapper.bind(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 552, in _dot_product_attention_fwd_impl
    outputs = _dot_product_attention_fwd_p.bind(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: [cudnn_frontend] Error: No valid execution plans built.
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6864): 'graph_.build_plans(cudnn_handle)'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
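(Per the note at the end of the traceback, the filtered JAX-internal frames can be restored for debugging. Besides setting the environment variable, one equivalent way, assuming a recent JAX, is the documented config option:)

import jax
jax.config.update("jax_traceback_filtering", "off")  # same effect as JAX_TRACEBACK_FILTERING=off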

If I switch the implementation to 'xla', the code works fine, and it also works if I install jax[cuda12]. The bug is reproducible on both Hopper and Blackwell GPUs (H100, B200, RTX 5090). A stopgap wrapper I've been using is sketched below.
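This is just a sketch (the wrapper name is illustrative, not a JAX API), and it assumes eager execution: under jit the cudnn failure surfaces at compile time rather than as a catchable call-site error.

import jax

def dpa_with_fallback(query, key, value, **kwargs):
    # Try the cudnn flash-attention path first; fall back to the pure-XLA
    # implementation when cudnn cannot build an execution plan (as on CUDA 13).
    try:
        return jax.nn.dot_product_attention(
            query=query, key=key, value=value,
            implementation='cudnn', **kwargs)
    except Exception:  # jaxlib raises XlaRuntimeError here
        return jax.nn.dot_product_attention(
            query=query, key=key, value=value,
            implementation='xla', **kwargs)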

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.3.3
python: 3.12.11 | packaged by conda-forge | (main, Jun  4 2025, 14:45:31) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 5090-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='glados', release='6.8.0-85-generic', version='#85-Ubuntu SMP PREEMPT_DYNAMIC Thu Sep 18 15:26:59 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_PREALLOCATE=false

$ nvidia-smi
Tue Oct  7 22:17:38 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:02:00.0 Off |                  N/A |
|  0%   47C    P1             23W /  575W |     538MiB /  32607MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A          116270      C   python                                  528MiB |
+-----------------------------------------------------------------------------------------+
