Open
Labels
bug (Something isn't working)
Description
Summary:
On WSL2, after running inference with a loaded checkpoint, JAX array transfers from device to host fail with CUDA_ERROR_OUT_OF_MEMORY, despite:
- Tuning XLA flags like preallocate / mem_fraction / allocator
- Sufficient VRAM available (we can safely triple the batch size without CUDA OOM during inference)
- Successful completion of block_until_ready()
- Valid device pointers accessible via __cuda_array_interface__
This happens with np.asarray, jax.device_get, and jax.device_put(..., cpu_device).
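For readers trying to reproduce this, the "preallocate / mem_fraction / allocator" settings mentioned above presumably map onto JAX's documented GPU memory environment variables. That mapping is an assumption, since the report does not name the variables or the values tried. A minimal sketch of setting them follows. The values shown are illustrative only, and `apply_gpu_memory_settings` is a hypothetical helper, not part of JAX.

```python
import os

# Assumed mapping of "preallocate / mem_fraction / allocator" onto JAX's
# documented GPU memory environment variables. Values are illustrative,
# not the ones used in the original report.
GPU_MEMORY_SETTINGS = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",   # do not reserve a large VRAM pool at startup
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "0.50",   # pool size as a fraction of VRAM when preallocating
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",  # allocate and free on demand (slow, for debugging)
}


def apply_gpu_memory_settings(settings=GPU_MEMORY_SETTINGS):
    """Export the settings and return what was set.

    Hypothetical helper: call it before JAX initializes its GPU backend,
    most simply before the first `import jax`.
    """
    os.environ.update(settings)
    return {name: os.environ[name] for name in settings}


if __name__ == "__main__":
    print(apply_gpu_memory_settings())
```

Setting the variables from Python rather than the shell keeps the configuration next to the repro script, but it only takes effect if it runs before JAX starts up its backend.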
Repro sketch:
import jax.numpy as jnp
import numpy as np
# I'm loading and running a model like so:
params = load_checkpoint("model.npz") # Calls jnp.asarray to transfer to GPU
result = model(params, batch) # Succeeds, and we see feedback with callbacks
result.block_until_ready() # Succeeds
# Any transfer attempt fails
output = np.asarray(result) # CUDA_ERROR_OUT_OF_MEMORY
# Even this fails
output = result[0, 0, 0, 0].item() # CUDA_ERROR_OUT_OF_MEMORY
# Subsequent unrelated transfers fail
fresh = jnp.ones(10)
np.asarray(fresh) # CUDA_ERROR_OUT_OF_MEMORY
My handwavy guess at the root cause:
Either the checkpoint loading or inference operation corrupts JAX's host-device transfer path on WSL2, while leaving the CUDA context functional for compute. Direct cudaMemcpy works fine, suggesting this is specific to JAX's transfer implementation interacting badly with WSL2's virtualized CUDA driver.
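One way to narrow this down, offered as an assumption rather than something the report establishes, is to check whether page-locked (pinned) host memory can still be allocated once the failure has started. The workaround's plain cudaMemcpy into pageable NumPy memory succeeds, so a failing pinned allocation would point at host-side staging memory rather than VRAM. The sketch below calls the CUDA runtime's `cudaMallocHost` and `cudaFreeHost` through ctypes. `probe_pinned_host_alloc` is a hypothetical helper name, and the library name `libcudart.so` is taken from the workaround in this report.

```python
import ctypes


def probe_pinned_host_alloc(nbytes=1 << 20, lib_name="libcudart.so"):
    """Try to allocate and free `nbytes` of page-locked host memory.

    Returns the cudaError_t code from cudaMallocHost (0 means success),
    or None if the CUDA runtime library cannot be loaded.
    """
    try:
        cudart = ctypes.CDLL(lib_name)
    except OSError:
        return None

    # Declare signatures so pointers and sizes are not truncated to int.
    cudart.cudaMallocHost.argtypes = [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]
    cudart.cudaMallocHost.restype = ctypes.c_int
    cudart.cudaFreeHost.argtypes = [ctypes.c_void_p]
    cudart.cudaFreeHost.restype = ctypes.c_int

    ptr = ctypes.c_void_p()
    status = cudart.cudaMallocHost(ctypes.byref(ptr), nbytes)
    if status == 0:
        # Release the buffer again; this probe only tests the allocation.
        cudart.cudaFreeHost(ptr)
    return status


if __name__ == "__main__":
    print("cudaMallocHost status:", probe_pinned_host_alloc())
```

Running the probe in the same process, once before loading the checkpoint and once after the first failed transfer, would show whether pinned allocations change from succeeding to failing. A status of 2 corresponds to cudaErrorMemoryAllocation.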
Workaround:
import ctypes
import numpy as np
cudart = ctypes.CDLL("libcudart.so")
def direct_cuda_transfer(jax_array):
    jax_array.block_until_ready()
    ptr = jax_array.__cuda_array_interface__["data"][0]
    host_array = np.empty(jax_array.shape, dtype=jax_array.dtype)
    error = cudart.cudaMemcpy(
        ctypes.c_void_p(host_array.ctypes.data),
        ctypes.c_void_p(ptr),
        ctypes.c_size_t(jax_array.nbytes),
        ctypes.c_int(2)  # cudaMemcpyDeviceToHost
    )
    if error != 0:
        raise RuntimeError(f"cudaMemcpy failed with code {error}")
    return host_array
# Works reliably
output = direct_cuda_transfer(result)
System info (python version, jaxlib version, accelerator, etc.)
- OS: Ubuntu 24.04 running in WSL 2.6.1.0 on Windows 10
- GPU: Nvidia 3090
- CUDA driver: 581.80
- Python: 3.12.3
- jax[cuda12]==0.8.1
- jaxlib==0.8.1