Skip to content

Commit 29be596

Browse files
authored
fix: undefined symbol cudaGetDriverEntryPointByVersion with CUDA >= 12.5 (flashinfer-ai#928)
## Problem: When ① build flashinfer with CUDA >= 12.5 (using system-wide CUDA toolkit under `/usr/local/cuda`), and ② run with CUDA < 12.5 (using `libcudart.so` under the python environment `/usr/local/lib/python3.10/dist-packages/nvidia/cuda_runtime/lib/libcudart.so.12`), one would meet the issue of undefined symbol `cudaGetDriverEntryPointByVersion`, which is introduced since CUDA 12.5. <img width="824" alt="image" src="https://github.com/user-attachments/assets/30322352-2cdc-45b5-adc3-2eb82fbac45e" /> This issue has been reported and fixed in other projects: - cutlass: NVIDIA/cutlass#2086 - sglang: sgl-project/sglang#3372 ## Fix This fix is a workaround of this issue which forces flashinfer use system-wide CUDA toolkit, refer to the fix in [sglang](sgl-project/sglang#3372), cc @zhyncs.
1 parent 1e2515e commit 29be596

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

flashinfer/jit/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@
5252
from .env import *
5353
from .utils import parallel_load_modules as parallel_load_modules
5454

55+
56+
import os
57+
import ctypes
58+
cuda_lib_path = os.environ.get('CUDA_LIB_PATH', '/usr/local/cuda/targets/x86_64-linux/lib/')
59+
if os.path.exists(f"{cuda_lib_path}/libcudart.so.12"):
60+
ctypes.CDLL(f"{cuda_lib_path}/libcudart.so.12", mode=ctypes.RTLD_GLOBAL)
61+
62+
5563
try:
5664
from .. import flashinfer_kernels, flashinfer_kernels_sm90 # noqa: F401
5765
from .aot_config import prebuilt_ops_uri as prebuilt_ops_uri

0 commit comments

Comments
 (0)