Open
Labels: err: Runtime
Description
`jax.nn.dot_product_attention` fails with `XlaRuntimeError` when using the cuDNN implementation in JAX 0.7.2 with CUDA 13. I reported this bug in the JAX repo (this issue), and a JAX developer told me to report it here.
Here's a quick repro:
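```python
import jax
import jax.numpy as jnp

qkv = jax.random.normal(jax.random.PRNGKey(42), (2, 8192, 4, 128), dtype=jnp.bfloat16)  # (batch, seq_len, num_heads, head_dim); reused as Q, K and V
flash_attention_output = jax.nn.dot_product_attention(
    query=qkv, key=qkv, value=qkv,
    implementation='cudnn',  # request the cuDNN fused (flash) attention kernel
    is_causal=True,
)
```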
This piece of code produces the following error:
```
Traceback (most recent call last):
  File "/nfs/nest/program/scalax/cuda13/repro.py", line 5, in <module>
    flash_attention_output = jax.nn.dot_product_attention(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/nn/functions.py", line 1238, in dot_product_attention
    out = cudnn_dot_product_attention(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 2060, in dot_product_attention
    output = _dot_product_attention(
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 1219, in _dot_product_attention
    output = _dot_product_attention_fwd(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 443, in _dot_product_attention_fwd
    outputs = _dot_product_attention_fwd_p_wrapper.bind(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/young/miniforge3/envs/jax_cuda13/lib/python3.12/site-packages/jax/_src/cudnn/fused_attention_stablehlo.py", line 552, in _dot_product_attention_fwd_impl
    outputs = _dot_product_attention_fwd_p.bind(
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: [cudnn_frontend] Error: No valid execution plans built.
in external/xla/xla/stream_executor/cuda/cuda_dnn.cc(6864): 'graph_.build_plans(cudnn_handle)'
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
```
If I switch the implementation to `xla`, the code works fine, and it also works when I install `jax[cuda12]`. The bug reproduces on both Hopper and Blackwell GPUs (H100, B200, RTX 5090).
System info (python version, jaxlib version, accelerator, etc.)
```
jax: 0.7.2
jaxlib: 0.7.2
numpy: 2.3.3
python: 3.12.11 | packaged by conda-forge | (main, Jun 4 2025, 14:45:31) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 5090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='glados', release='6.8.0-85-generic', version='#85-Ubuntu SMP PREEMPT_DYNAMIC Thu Sep 18 15:26:59 UTC 2025', machine='x86_64')
XLA_PYTHON_CLIENT_PREALLOCATE=false
$ nvidia-smi
Tue Oct 7 22:17:38 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 5090        On  |   00000000:02:00.0 Off |                  N/A |
|  0%   47C    P1             23W /  575W |     538MiB /  32607MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A          116270      C   python                                  528MiB |
+-----------------------------------------------------------------------------------------+
```
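The version and device report above appears to follow the layout printed by `jax.print_environment_info()`. A minimal way to regenerate it, for example to compare the working `jax[cuda12]` environment with the failing CUDA 13 one, is sketched below; `return_string=True` returns the report as a string instead of printing it.

```python
import jax

# Collect jax/jaxlib/numpy/python versions, device info, platform details,
# JAX/XLA environment variables and, when available, nvidia-smi output.
report = jax.print_environment_info(return_string=True)
print(report)
```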