[XLA:GPU] Jax example produces CUDA_ERROR_INVALID_VALUE: invalid argument #33220

@justinjfu

Original bug: jax-ml/jax#32882

The following JAX program results in XLA:GPU generating a kernel with invalid arguments:

import jax
import jax.numpy as jnp
import numpy as np


def sink(*args, **kwargs):
    jax.experimental.io_callback(lambda *args, **kwargs: None, None, *args, **kwargs)


@jax.jit
def main(int_value, input_params, slice_index):
    def inner_fn(shape, params):
        h1_dims = (shape[0], max(shape[1], 16))
        h1 = jnp.concatenate(
            [
                jnp.ones((*h1_dims, 127)) * jnp.array([0.0])[0],
                jnp.zeros((*h1_dims, 1)),
            ],
            axis=-1,
        )
        h2 = params["p1"] * h1
        h3 = h2 / (h2**2).sum(axis=-1, keepdims=True)
        h4 = params["p2"] / (jnp.linalg.norm(params["p2"], axis=-1, keepdims=True))
        h5 = h4[:, :-1].reshape((-1,))
        h6 = h3 * (np.zeros((128,)) * h5)
        return jax.lax.dynamic_slice_in_dim(h6, slice_index, shape[1], axis=-2)

    def outer_fn_1():
        sink(inner_fn((1, 1), input_params))

    def outer_fn_2():
        def loss_fn(params):
            return inner_fn((8, 17), params).sum()

        sink(jax.value_and_grad(loss_fn)(input_params))

    def main_loop_fn(*args, **kwargs):
        jax.lax.cond(int_value != 0, outer_fn_1, outer_fn_1)
        jax.lax.cond(int_value != 0, outer_fn_2, outer_fn_2)

    jax.lax.fori_loop(0, 1, main_loop_fn, None)


main(0, {"p1": jnp.zeros((1,)), "p2": jnp.zeros((32, 5))}, jnp.array(0))
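
To make the shapes in the reproducer easier to follow, here is a NumPy-only sketch of `inner_fn`. It is an illustration, not part of the original report, and it does not reproduce the failure (it never touches XLA or the GPU). Assumptions: the name `inner_fn_numpy` is hypothetical, NumPy computes in float64 where JAX would default to float32, and `jax.lax.dynamic_slice_in_dim` is replaced by a plain slice, which only matches for an in-bounds index such as the `0` used in the report.

```python
import numpy as np


def inner_fn_numpy(shape, params, slice_index=0):
    """CPU-only mirror of inner_fn, for inspecting shapes and values."""
    h1_dims = (shape[0], max(shape[1], 16))
    # 127 columns of ones scaled by 0.0, plus one column of zeros: all zeros.
    h1 = np.concatenate(
        [np.ones((*h1_dims, 127)) * 0.0, np.zeros((*h1_dims, 1))],
        axis=-1,
    )  # (shape[0], max(shape[1], 16), 128)
    # The divisions below hit 0/0; silence the NumPy warnings.
    with np.errstate(divide="ignore", invalid="ignore"):
        h2 = params["p1"] * h1
        h3 = h2 / (h2**2).sum(axis=-1, keepdims=True)
        h4 = params["p2"] / np.linalg.norm(params["p2"], axis=-1, keepdims=True)
        h5 = h4[:, :-1].reshape((-1,))  # (32, 4) flattened to (128,)
        h6 = h3 * (np.zeros((128,)) * h5)
    # Plain slice standing in for dynamic_slice_in_dim on axis -2.
    return h6[:, slice_index:slice_index + shape[1], :]


params = {"p1": np.zeros((1,)), "p2": np.zeros((32, 5))}
print(inner_fn_numpy((1, 1), params).shape)   # branch used by outer_fn_1
print(inner_fn_numpy((8, 17), params).shape)  # branch used by outer_fn_2
```

With these inputs every element of the result is NaN, because `h1` is all zeros and `h3` divides zero by zero. Whether that plays any role in the CUDA error is not established here.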

Result:

E1027 14:49:00.537894  832649 pjrt_stream_executor_client.cc:2839] Execution of replica 0 failed: INTERNAL: CUDA error: Failed to add kernel node to a CUDA graph: CUDA_ERROR_INVALID_VALUE: invalid argument
Traceback (most recent call last):
  File ".../jax_bug.py", line 44, in <module>
    main(0, {"p1": jnp.zeros((1,)), "p2": jnp.ones((32, 5))}, jnp.array(0))
    ~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: Failed to add kernel node to a CUDA graph: CUDA_ERROR_INVALID_VALUE: invalid argument
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
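
Since the error is raised while adding a kernel node to a CUDA graph, one way to gather more information is to rerun the script with CUDA graph capture turned off, an HLO dump enabled, and traceback filtering disabled. The sketch below only builds the environment for such a run. Assumptions: the helper name `debug_env` and the dump directory are hypothetical, and it has not been verified that disabling command buffers avoids this particular failure. The variables must be set before JAX initializes its backend, which is why they are passed to a fresh process.

```python
import os


def debug_env(base_env=None, dump_dir="/tmp/xla_dump"):
    """Return a copy of the environment with extra debugging settings."""
    env = dict(os.environ if base_env is None else base_env)
    extra_flags = [
        # Empty value: ask XLA:GPU not to lower anything to command buffers
        # (CUDA graphs).
        "--xla_gpu_enable_command_buffer=",
        # Write HLO modules to this directory (hypothetical path).
        f"--xla_dump_to={dump_dir}",
    ]
    existing = env.get("XLA_FLAGS", "")
    env["XLA_FLAGS"] = " ".join([existing, *extra_flags]).strip()
    # Include JAX-internal frames, as suggested by the traceback footer.
    env["JAX_TRACEBACK_FILTERING"] = "off"
    return env


if __name__ == "__main__":
    env = debug_env()
    for key in ("XLA_FLAGS", "JAX_TRACEBACK_FILTERING"):
        print(f"{key}={env[key]}")
```

A possible use is `subprocess.run([sys.executable, "jax_bug.py"], env=debug_env())`, where `jax_bug.py` is the script name taken from the traceback.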

System Information:

jax:    0.8.0
jaxlib: 0.8.0
numpy:  2.3.4
python: 3.13.9 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 19:16:10) [GCC 11.2.0]
device info: NVIDIA GeForce RTX 3080 Ti-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='...', release='6.8.0-85-generic', version='#85~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Fri Sep 19 16:18:59 UTC 2', machine='x86_64')
$ nvidia-smi
Mon Oct 27 14:45:38 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
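
The version block above appears to be the output of `jax.print_environment_info()`. For anyone comparing their own setup against it, a minimal sketch for collecting the same report as a string follows. The wrapper name `environment_report` and the fallback text used when JAX is not installed are assumptions added for illustration.

```python
import platform
import sys


def environment_report():
    """Return JAX's environment report, or a short fallback without JAX."""
    try:
        import jax
    except ImportError:
        # Fallback so the helper still works on machines without JAX.
        return (
            "jax:    not installed\n"
            f"python: {sys.version}\n"
            f"platform: {platform.uname()}"
        )
    # return_string=True returns the report instead of printing it.
    return jax.print_environment_info(return_string=True)


if __name__ == "__main__":
    print(environment_report())
```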

Labels: GPU, XLA on GPU
