
Device -> host transfer causes CUDA OOM on WSL2 #33694

@jkyl

Description


Summary:

On WSL2, after running inference with a loaded checkpoint, JAX array transfers from device to host fail with CUDA_ERROR_OUT_OF_MEMORY, despite:

  • Tuning XLA flags like preallocate / mem_fraction / allocator
  • Sufficient VRAM available (we can safely triple the batch size without CUDA OOM during inference)
  • Successful completion of block_until_ready()
  • Valid device pointers accessible via __cuda_array_interface__
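The three knobs above correspond to JAX's `XLA_PYTHON_CLIENT_*` environment variables. Below is a minimal sketch of applying them from Python. It assumes the settings are exported before JAX initializes its GPU backend. The values are illustrative only, since the issue doesn't say which combinations were tried, and `configure_xla_memory` is a hypothetical helper.

```python
import os
import sys
import warnings

# Example values only; the exact settings tried in this issue are not listed.
XLA_MEMORY_SETTINGS = {
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",   # allocate on demand instead of reserving most of VRAM at startup
    "XLA_PYTHON_CLIENT_MEM_FRACTION": ".50",    # fraction of VRAM to reserve when preallocation is enabled
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",  # allocate/free exactly what is needed (slow; meant for debugging OOMs)
}


def configure_xla_memory(settings=XLA_MEMORY_SETTINGS):
    """Export the settings so JAX sees them when it initializes its GPU backend."""
    if "jax" in sys.modules:
        # The backend may already be initialized, in which case these are ignored.
        warnings.warn("jax is already imported; these settings may not take effect")
    os.environ.update(settings)
    return sorted(settings)
```

Calling `configure_xla_memory()` at the very top of the entry script, before `import jax`, is the simplest way to make sure the values are picked up.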

This happens with np.asarray, jax.device_get, and jax.device_put(..., cpu_device).
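The harness below checks each of these paths on the same array. It is a sketch: `probe_transfers` and `jax_transfer_paths` are hypothetical helpers, and it assumes `cpu_device` means `jax.devices("cpu")[0]`. The `device_put` result is forced with `block_until_ready()` because `device_put` is asynchronous and might otherwise not raise at the call site.

```python
import numpy as np


def probe_transfers(x, transfers):
    """Try each device -> host transfer path on `x` and record the outcome."""
    outcomes = {}
    for name, transfer in transfers.items():
        try:
            transfer(x)
            outcomes[name] = "ok"
        except Exception as exc:  # JAX surfaces the CUDA error as its own runtime error type
            outcomes[name] = f"{type(exc).__name__}: {exc}"
    return outcomes


def jax_transfer_paths():
    """The three paths named in the issue; jax is imported lazily."""
    import jax

    cpu = jax.devices("cpu")[0]
    return {
        "np.asarray": np.asarray,
        "jax.device_get": jax.device_get,
        "jax.device_put(cpu)": lambda a: jax.device_put(a, cpu).block_until_ready(),
    }
```

On the affected machine, `probe_transfers(result, jax_transfer_paths())` would be expected to report the OOM for all three entries.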

Repro sketch:

import jax.numpy as jnp
import numpy as np

# load_checkpoint, model and batch are my own code (not shown here).
params = load_checkpoint("model.npz")  # calls jnp.asarray to move the weights to the GPU
result = model(params, batch)  # succeeds, and our callbacks fire during the run
result.block_until_ready()  # succeeds

# Any device -> host transfer then fails
output = np.asarray(result)  # CUDA_ERROR_OUT_OF_MEMORY

# Even fetching a single element fails
output = result[0, 0, 0, 0].item()  # CUDA_ERROR_OUT_OF_MEMORY

# Transfers of fresh, unrelated arrays fail too
fresh = jnp.ones(10)
np.asarray(fresh)  # CUDA_ERROR_OUT_OF_MEMORY
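The issue doesn't show `load_checkpoint`. For anyone trying to reproduce, a hypothetical stand-in that matches its comment (read an `.npz` file and call `jnp.asarray` on every array) might look like this. The `to_device` parameter is an addition so the loader can be exercised without a GPU.

```python
import numpy as np


def load_checkpoint(path, to_device=None):
    """Hypothetical loader: read every array in an .npz file and move it to the default device."""
    if to_device is None:
        import jax.numpy as jnp  # deferred so the helper can be tested without jax

        to_device = jnp.asarray
    with np.load(path) as data:
        return {name: to_device(data[name]) for name in data.files}
```

With the default `to_device`, each array is placed on JAX's default device (the GPU in this setup), which is the step the issue suspects alongside inference.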

My handwavy guess at the root cause:

My guess is that either the checkpoint loading or the inference step leaves JAX's host-device transfer path on WSL2 in a broken state, while the CUDA context stays usable for compute. A direct cudaMemcpy from the same device pointer still works (see the workaround below), which suggests the problem lies in how JAX performs the transfer, interacting badly with WSL2's virtualized CUDA driver.
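One way to test this guess is to check a tiny, fresh transfer after each stage and report the first stage after which it breaks. Below is a sketch with hypothetical helpers (`find_breaking_stage`, `jax_transfer_canary`). The stages would wrap the user's own loading and inference code.

```python
def find_breaking_stage(stages, canary):
    """Run (name, fn) stages in order; return the name of the first stage after
    which `canary()` reports a broken transfer path, or None if it never breaks."""
    if not canary():
        return "<before first stage>"
    for name, stage in stages:
        stage()
        if not canary():
            return name
    return None


def jax_transfer_canary():
    """True if a fresh, tiny array can still be copied back to the host."""
    import jax.numpy as jnp  # deferred so the bisection logic does not need jax
    import numpy as np

    try:
        np.asarray(jnp.ones(10))
        return True
    except Exception:
        return False
```

For example, pass `[("load checkpoint", load_fn), ("inference", infer_fn)]` with `canary=jax_transfer_canary`. A result of `"load checkpoint"` would point at the `jnp.asarray` uploads rather than at the model itself.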

Workaround:

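import ctypes
import numpy as np

# Depending on how CUDA was installed, the runtime may only exist under a
# versioned name (e.g. libcudart.so.12).
cudart = ctypes.CDLL("libcudart.so")
cudart.cudaMemcpy.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, ctypes.c_int]
cudart.cudaMemcpy.restype = ctypes.c_int  # cudaError_t; 0 == cudaSuccess

CUDA_MEMCPY_DEVICE_TO_HOST = 2  # cudaMemcpyKind.cudaMemcpyDeviceToHost

def direct_cuda_transfer(jax_array):
    """Copy a single-device, contiguous JAX GPU array into a new NumPy array via cudaMemcpy."""
    jax_array.block_until_ready()  # make sure the producing computation has finished

    ptr = jax_array.__cuda_array_interface__["data"][0]
    host_array = np.empty(jax_array.shape, dtype=jax_array.dtype)

    error = cudart.cudaMemcpy(
        host_array.ctypes.data,
        ptr,
        jax_array.nbytes,
        CUDA_MEMCPY_DEVICE_TO_HOST,
    )

    if error != 0:
        raise RuntimeError(f"cudaMemcpy failed with code {error}")

    return host_array

# Works reliably
output = direct_cuda_transfer(result)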
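The workaround only covers single-device GPU arrays, so it may be preferable to keep the normal path and fall back only when this specific error appears. Below is a hedged sketch. `to_host` is a hypothetical helper, and matching on the error text is an assumption about how the failure is reported.

```python
import numpy as np


def to_host(x, fallback, primary=np.asarray):
    """Use the normal transfer path; switch to `fallback` only on the CUDA OOM seen in this issue."""
    try:
        return primary(x)
    except Exception as exc:
        if "CUDA_ERROR_OUT_OF_MEMORY" not in str(exc):
            raise  # unrelated failures should still surface
        return fallback(x)
```

Usage would be `output = to_host(result, fallback=direct_cuda_transfer)`. The cudaMemcpy path is then only used when JAX's own transfer actually fails.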

System info (python version, jaxlib version, accelerator, etc.)

  • OS: Ubuntu 24.04 running in WSL 2.6.1.0 on Windows 10
  • GPU: NVIDIA GeForce RTX 3090
  • CUDA driver: 581.80
  • Python: 3.12.3
    • jax[cuda12]==0.8.1
    • jaxlib==0.8.1
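When comparing setups, the snippet below captures the relevant details from inside the WSL2 environment. It is a sketch: `environment_summary` is hypothetical, and the WSL check relies on the kernel release string containing "microsoft", which is a common heuristic rather than an official API. JAX's own `jax.print_environment_info()` gives a fuller report.

```python
import platform
from importlib import metadata


def is_wsl():
    # WSL kernels report release strings such as "5.15.x-microsoft-standard-WSL2".
    return "microsoft" in platform.uname().release.lower()


def package_version(name):
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def environment_summary():
    return {
        "python": platform.python_version(),
        "wsl": is_wsl(),
        "jax": package_version("jax"),
        "jaxlib": package_version("jaxlib"),
    }
```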

Labels: bug
