Status: Open
Labels: GPU (XLA on GPU)
Description
Original bug: jax-ml/jax#32882
The following JAX program results in XLA:GPU generating a kernel with invalid arguments:
import jax
import jax.numpy as jnp
import numpy as np


def sink(*args, **kwargs):
    jax.experimental.io_callback(lambda *args, **kwargs: None, None, *args, **kwargs)


@jax.jit
def main(int_value, input_params, slice_index):
    def inner_fn(shape, params):
        h1_dims = (shape[0], max(shape[1], 16))
        h1 = jnp.concatenate(
            [
                jnp.ones((*h1_dims, 127)) * jnp.array([0.0])[0],
                jnp.zeros((*h1_dims, 1)),
            ],
            axis=-1,
        )
        h2 = params["p1"] * h1
        h3 = h2 / (h2**2).sum(axis=-1, keepdims=True)
        h4 = params["p2"] / (jnp.linalg.norm(params["p2"], axis=-1, keepdims=True))
        h5 = h4[:, :-1].reshape((-1,))
        h6 = h3 * (np.zeros((128,)) * h5)
        return jax.lax.dynamic_slice_in_dim(h6, slice_index, shape[1], axis=-2)

    def outer_fn_1():
        sink(inner_fn((1, 1), input_params))

    def outer_fn_2():
        def loss_fn(params):
            return inner_fn((8, 17), params).sum()

        sink(jax.value_and_grad(loss_fn)(input_params))

    def main_loop_fn(*args, **kwargs):
        jax.lax.cond(int_value != 0, outer_fn_1, outer_fn_1)
        jax.lax.cond(int_value != 0, outer_fn_2, outer_fn_2)

    jax.lax.fori_loop(0, 1, main_loop_fn, None)


main(0, {"p1": jnp.zeros((1,)), "p2": jnp.zeros((32, 5))}, jnp.array(0))
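For reference, the shapes flowing through `inner_fn` can be worked out by hand. The sketch below mirrors each step using NumPy shape arithmetic only; `inner_fn_shapes` is a hypothetical helper, not part of the reproducer, and it assumes the parameter shapes passed on the last line (`p1` of shape (1,), `p2` of shape (32, 5)). It computes no values and does not reproduce the failure.

```python
import numpy as np


def inner_fn_shapes(shape, p1_shape=(1,), p2_shape=(32, 5)):
    """Trace the shapes produced by inner_fn in the reproducer (hypothetical helper).

    Only NumPy broadcasting rules are applied; no arrays are materialized.
    """
    # h1: 127 ones and 1 zero concatenated on the last axis -> 128 features.
    h1 = (shape[0], max(shape[1], 16), 127 + 1)
    # h2 = params["p1"] * h1 broadcasts p1 against h1.
    h2 = np.broadcast_shapes(p1_shape, h1)
    # h3 divides h2 by a keepdims reduction, so it keeps h2's shape.
    h3 = h2
    # h4 has p2's shape; h5 drops the last column and flattens: 32 * 4 = 128.
    h5 = (p2_shape[0] * (p2_shape[1] - 1),)
    # h6 = h3 * (np.zeros((128,)) * h5)
    h6 = np.broadcast_shapes(h3, np.broadcast_shapes((128,), h5))
    # dynamic_slice_in_dim(..., shape[1], axis=-2) keeps shape[1] rows on axis -2.
    out = h6[:-2] + (shape[1],) + h6[-1:]
    return {"h1": h1, "h6": h6, "out": out}


print(inner_fn_shapes((1, 1)))   # path taken by outer_fn_1
print(inner_fn_shapes((8, 17)))  # path taken by outer_fn_2 (under value_and_grad)
```

Under these assumptions, `outer_fn_1` slices one row out of a (1, 16, 128) array, while `outer_fn_2` slices all 17 rows of an (8, 17, 128) array. This trace alone does not show which of the two paths produces the invalid kernel.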
Result:
E1027 14:49:00.537894 832649 pjrt_stream_executor_client.cc:2839] Execution of replica 0 failed: INTERNAL: CUDA error: Failed to add kernel node to a CUDA graph: CUDA_ERROR_INVALID_VALUE: invalid argument
Traceback (most recent call last):
  File ".../jax_bug.py", line 44, in <module>
    main(0, {"p1": jnp.zeros((1,)), "p2": jnp.ones((32, 5))}, jnp.array(0))
    ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: Failed to add kernel node to a CUDA graph: CUDA_ERROR_INVALID_VALUE: invalid argument
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
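The error is raised while adding a kernel node to a CUDA graph, which XLA:GPU uses to implement command buffers. One way to check whether the failure is confined to that path is to rerun the reproducer with command buffers turned off. The sketch below assumes that setting the XLA flag `--xla_gpu_enable_command_buffer` to an empty value disables them, and that `XLA_FLAGS` must be set before JAX initializes its GPU backend; `disable_command_buffers` is a hypothetical helper, not part of JAX or XLA.

```python
import os

# Assumption: an empty value for this XLA flag disables every command-buffer
# (CUDA graph) command type in XLA:GPU.
NO_COMMAND_BUFFERS = "--xla_gpu_enable_command_buffer="


def disable_command_buffers(env=None):
    """Append the command-buffer flag to XLA_FLAGS and return the new value.

    `env` defaults to os.environ; pass a plain dict to try it without touching
    the real environment. The flag is only picked up if this runs before JAX
    initializes its GPU backend (simplest: before `import jax`).
    """
    if env is None:
        env = os.environ
    flags = env.get("XLA_FLAGS", "").split()
    # Avoid appending the same flag twice on repeated calls.
    if NO_COMMAND_BUFFERS not in flags:
        flags.append(NO_COMMAND_BUFFERS)
    env["XLA_FLAGS"] = " ".join(flags)
    return env["XLA_FLAGS"]
```

Calling `disable_command_buffers()` at the very top of the reproducer, before `import jax`, would apply it. If the program then runs, that points at the CUDA graph path, but it would be a workaround rather than a fix for the invalid kernel arguments.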
System Information:
jax: 0.8.0
jaxlib: 0.8.0
numpy: 2.3.4
python: 3.13.9 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 19:16:10) [GCC 11.2.0]
device info: NVIDIA GeForce RTX 3080 Ti-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='...', release='6.8.0-85-generic', version='#85~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 16:18:59 UTC 2', machine='x86_64')
$ nvidia-smi
Mon Oct 27 14:45:38 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+