Status: Open
Labels: bug (Something isn't working)
Description
While training a U-Net, I encountered the following error:
ubuntu@geli-3950:/large_nfs/risk-game-ai$ ./notebooks/unet_train.py
0it [14:04, ?it/s]
Traceback (most recent call last):
File "/large_nfs/risk-game-ai/./notebooks/unet_train.py", line 173, in <module>
model, opt_state, loss_value = train_step(model, img_batch.transpose(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/equinox/equinox/_jit.py", line 209, in __call__
return _call(self, False, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/equinox/equinox/_jit.py", line 263, in _call
marker, _, _ = out = jit_wrapper._cached(
^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/traceback_util.py", line 195, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/pjit.py", line 264, in cache_miss
executable, pgle_profiler, const_args) = _python_pjit_helper(
^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/pjit.py", line 146, in _python_pjit_helper
out_flat, compiled, profiler, const_args = _pjit_call_impl_python(
^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/pjit.py", line 1620, in _pjit_call_impl_python
compiled = computation.compile()
^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2517, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/interpreters/pxla.py", line 3063, in from_hlo
xla_executable = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/interpreters/pxla.py", line 2844, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/compiler.py", line 478, in compile_or_get_cached
return _compile_and_write_cache(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/compiler.py", line 746, in _compile_and_write_cache
executable = backend_compile_and_load(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/profiler.py", line 359, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/opt/jax/jax/_src/compiler.py", line 372, in backend_compile_and_load
raise e
File "/opt/jax/jax/_src/compiler.py", line 362, in backend_compile_and_load
return backend.compile_and_load(
^^^^^^^^^^^^^^^^^^^^^^^^^
jax.errors.JaxRuntimeError: INTERNAL: Autotuning failed for HLO: %custom-call.149 = (f32[4,51200,51200]{2,1,0}, s8[4194304]{0}) custom-call(%div.1106, %bitcast.115), custom_call_target="__cublas$gemm", metadata={op_name="jit(train_step)/jvp(vmap(d n, d e -> n e))/dot_general" source_file="/large_nfs/risk-game-ai/jax_dist_training/train_unet.py" source_line=487 source_end_line=487 source_column=18 source_end_column=53}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"gemm_backend_config":{"alpha_real":1,"beta":0,"dot_dimension_numbers":{"lhs_contracting_dimensions":["1"],"rhs_contracting_dimensions":["1"],"lhs_batch_dimensions":["0"],"rhs_batch_dimensions":["0"]},"alpha_imag":0,"precision_config":{"operand_precision":["DEFAULT","DEFAULT"],"algorithm":"ALG_UNSET"},"epilogue":"DEFAULT","lhs_stride":"204800","rhs_stride":"204800","grad_x":false,"grad_y":false,"damax_output":false},"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID"} with error: NOT_FOUND: No valid config found
Is this a known issue? Happy to try and isolate a repro if not.
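As a starting point for isolating a repro, here is a minimal sketch of what the failing op appears to be, based on the HLO metadata (op_name `jvp(vmap(d n, d e -> n e))/dot_general`, output `f32[4,51200,51200]`, batch dim 0, contracting dim 1 on both operands). The contracting-dim size of 4 is inferred from `lhs_stride=204800` (51200 × 4), and `N` is scaled way down here so the sketch runs on any device; these shapes are assumptions, not confirmed against the actual training code.

```python
import jax
import jax.numpy as jnp

# Scaled-down shapes; the failing op appears to be B=4, D=4, N=51200
# (D inferred from lhs_stride in the HLO backend_config -- an assumption).
B, D, N = 4, 4, 8

def outer(d_n, d_e):
    # 'd n, d e -> n e' as in the op_name metadata
    return jnp.einsum('dn,de->ne', d_n, d_e)

@jax.jit
def f(x, y):
    # vmap over the leading batch dim, matching the batched
    # __cublas$gemm custom call in the error
    return jax.vmap(outer)(x, y)

x = jnp.ones((B, D, N), jnp.float32)
y = jnp.ones((B, D, N), jnp.float32)
out = f(x, y)
print(out.shape)  # (4, 8, 8)
```

At the original N=51200, the f32 output alone is roughly 42 GB, far beyond the 12 GB on the RTX 2060, which may be why cuBLAS autotuning finds no valid config.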
System info (python version, jaxlib version, accelerator, etc.)
>>> jax.print_environment_info()
W1204 06:04:41.734757 1180002 cuda_executor.cc:1801] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
W1204 06:04:41.738270 1179624 cuda_executor.cc:1801] GPU interconnect information not available: INTERNAL: NVML doesn't support extracting fabric info or NVLink is not used by the device.
jax: 0.8.2.dev20251204
jaxlib: 0.8.2.dev20251204
numpy: 2.2.6
python: 3.12.3 (main, Nov 6 2025, 13:44:16) [GCC 13.3.0]
device info: NVIDIA GeForce RTX 2060-1, 1 local devices
process_count: 1
platform: uname_result(system='Linux', node='geli-3950', release='6.8.0-88-generic', version='#89-Ubuntu SMP PREEMPT_DYNAMIC Sat Oct 11 01:02:46 UTC 2025', machine='x86_64')
JAX_TOOLBOX_REF=34cbc5fd5adf3911e23ed67995778cf5edb68db3
XLA_FLAGS= --xla_gpu_enable_latency_hiding_scheduler=true
$ nvidia-smi
Thu Dec 4 06:04:41 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.95.05 Driver Version: 580.95.05 CUDA Version: 13.0 |
+-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 2060 Off | 00000000:25:00.0 Off | N/A |
| 0% 50C P2 37W / 184W | 101MiB / 12288MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| No running processes found |
+-----------------------------------------------------------------------------------------+