jax.errors.JaxRuntimeError: INTERNAL: Autotuning failed for HLO #33715

@geligeli

Description

While training a UNet, I encountered the following error:

ubuntu@geli-3950:/large_nfs/risk-game-ai$ ./notebooks/unet_train.py 
0it [14:04, ?it/s]
Traceback (most recent call last):
  File "/large_nfs/risk-game-ai/./notebooks/unet_train.py", line 173, in <module>
    model, opt_state, loss_value = train_step(model, img_batch.transpose(
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/equinox/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/equinox/equinox/_jit.py", line 263, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/traceback_util.py", line 195, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/pjit.py", line 264, in cache_miss
    executable, pgle_profiler, const_args) = _python_pjit_helper(
                                             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/pjit.py", line 146, in _python_pjit_helper
    out_flat, compiled, profiler, const_args = _pjit_call_impl_python(
                                               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/pjit.py", line 1620, in _pjit_call_impl_python
    compiled = computation.compile()
               ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2517, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 3063, in from_hlo
    xla_executable = _cached_compilation(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/interpreters/pxla.py", line 2844, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/compiler.py", line 478, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/compiler.py", line 746, in _compile_and_write_cache
    executable = backend_compile_and_load(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/profiler.py", line 359, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/jax/jax/_src/compiler.py", line 372, in backend_compile_and_load
    raise e
  File "/opt/jax/jax/_src/compiler.py", line 362, in backend_compile_and_load
    return backend.compile_and_load(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.JaxRuntimeError: INTERNAL: Autotuning failed for HLO: %custom-call.149 = (f32[4,51200,51200]{2,1,0}, s8[4194304]{0}) custom-call(%div.1106, %bitcast.115), custom_call_target="__cublas$gemm", metadata={op_name="jit(train_step)/jvp(vmap(d n, d e -> n e))/dot_general" source_file="/large_nfs/risk-game-ai/jax_dist_training/train_unet.py" source_line=487 source_end_line=487 source_column=18 source_end_column=53}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0"],"rhs_batch_dimensions":["0"]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","lhs_stride":"204800","rhs_stride":"204800","grad_x":false,"grad_y":false,"damax_output":false},"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID"} with error: NOT_FOUND: No valid config found

Is this a known issue? Happy to try and isolate a repro if not.
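For what it's worth, here is my own back-of-envelope arithmetic (an assumption on my part — decoding the shapes and strides in the failing custom-call as I read them): the GEMM output `f32[4,51200,51200]` alone is far larger than this card's 12 GiB of VRAM, which may be why the autotuner finds no valid config.

```python
# Hypothetical size check for the failing custom-call (not from the report
# itself): output is f32[4, 51200, 51200]; lhs/rhs strides of 204800
# (= 4 * 51200) suggest a contracting dimension of size 4.
batch, m, n = 4, 51200, 51200
bytes_per_f32 = 4

# Bytes needed for the GEMM output buffer alone.
out_bytes = batch * m * n * bytes_per_f32
print(f"GEMM output alone: {out_bytes / 2**30:.1f} GiB")  # ~39.1 GiB

# VRAM reported by nvidia-smi below: 12288 MiB.
vram_bytes = 12288 * 2**20
print(f"Available VRAM:    {vram_bytes / 2**30:.1f} GiB")  # 12.0 GiB
```

If that reading is right, this is less an autotuning bug than an out-of-memory intermediate, though the `NOT_FOUND: No valid config found` error message could arguably be clearer about it.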

System info (python version, jaxlib version, accelerator, etc.)

>>> jax.print_environment_info()
W1204 06:04:41.734757 1180002 cuda_executor.cc:1801] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1204 06:04:41.738270 1179624 cuda_executor.cc:1801] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
jax:    0.8.2.dev20251204
jaxlib: 0.8.2.dev20251204
numpy:  2.2.6
python: 3.12.3 (main, Nov  6 2025, 13:44:16) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 2060-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='geli-3950', release='6.8.0-88-generic', version='#89-Ubuntu SMP PREEMPT_DYNAMIC Sat Oct 11 01:02:46 UTC 2025', machine='x86_64')
JAX_TOOLBOX_REF=34cbc5fd5adf3911e23ed67995778cf5edb68db3
XLA_FLAGS= --xla_gpu_enable_latency_hiding_scheduler=true
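One thing I could try as a workaround (unverified on my end, just an assumption based on XLA's GPU flags): lowering the autotune level so compilation skips the cuBLAS algorithm search and falls back to a default config instead of failing.

```shell
# Possible workaround sketch (unverified): disable XLA GPU autotuning
# while keeping the existing latency-hiding-scheduler flag.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_autotune_level=0"
./notebooks/unet_train.py
```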

$ nvidia-smi
Thu Dec  4 06:04:41 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05              Driver Version: 580.95.05      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 2060        Off |   00000000:25:00.0 Off |                  N/A |
|  0%   50C    P2             37W /  184W |     101MiB /  12288MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+
