Skip to content

Commit 14c72cc

Browse files
committed
Use path_finder.load_nvidia_dynamic_library("nvrtc") from cuda/bindings/_bindings/cynvrtc.pyx.in
1 parent d12cbf5 commit 14c72cc

File tree

2 files changed

+8
-62
lines changed

2 files changed

+8
-62
lines changed

Diff for: cuda_bindings/cuda/bindings/_bindings/cynvrtc.pyx.in

+6-62
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
# This code was automatically generated with version 12.8.0. Do not modify it directly.
1010
{{if 'Windows' == platform.system()}}
1111
import os
12-
import site
13-
import struct
1412
import win32api
15-
from pywintypes import error
1613
{{else}}
1714
cimport cuda.bindings._lib.dlfcn as dlfcn
15+
from libc.stdint cimport uintptr_t
1816
{{endif}}
17+
from cuda.bindings import path_finder
1918

2019
cdef bint __cuPythonInit = False
2120
{{if 'nvrtcGetErrorString' in found_functions}}cdef void *__nvrtcGetErrorString = NULL{{endif}}
@@ -47,74 +46,17 @@ cdef bint __cuPythonInit = False
4746

4847
cdef int cuPythonInit() except -1 nogil:
4948
{{if 'Windows' != platform.system()}}
50-
cdef char* err_msg
49+
cdef void* handle = NULL
5150
{{endif}}
5251

5352
global __cuPythonInit
5453
if __cuPythonInit:
5554
return 0
5655
__cuPythonInit = True
5756

58-
# Load library
59-
{{if 'Windows' == platform.system()}}
60-
with gil:
61-
# First check if the DLL has been loaded by 3rd parties
62-
try:
63-
handle = win32api.GetModuleHandle("nvrtc64_120_0.dll")
64-
except:
65-
handle = None
66-
67-
# Else try default search
68-
if not handle:
69-
LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000
70-
try:
71-
handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS)
72-
except:
73-
pass
74-
75-
# Final check if DLLs can be found within pip installations
76-
if not handle:
77-
site_packages = [site.getusersitepackages()] + site.getsitepackages()
78-
for sp in site_packages:
79-
mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin")
80-
if not os.path.isdir(mod_path):
81-
continue
82-
os.add_dll_directory(mod_path)
83-
LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000
84-
LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100
85-
try:
86-
handle = win32api.LoadLibraryEx(
87-
# Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path...
88-
os.path.join(mod_path, "nvrtc64_120_0.dll"),
89-
0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)
90-
91-
# Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is
92-
# located in the same mod_path.
93-
# Update PATH environ so that the two dlls can find each other
94-
os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path))
95-
except:
96-
pass
97-
98-
if not handle:
99-
raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll')
100-
{{else}}
101-
with gil:
102-
print("\nLOOOK dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)", flush=True)
103-
handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)
104-
if handle == NULL:
105-
with gil:
106-
err_msg = dlfcn.dlerror()
107-
if err_msg == NULL:
108-
err_msg_str = 'Unknown error'
109-
else:
110-
err_msg_str = err_msg.decode('utf-8', errors='backslashreplace')
111-
raise RuntimeError(f'Failed to dlopen libnvrtc.so.12: {err_msg_str}')
112-
{{endif}}
113-
114-
115-
# Load function
11657
{{if 'Windows' == platform.system()}}
11758
with gil:
59+
handle = path_finder.load_nvidia_dynamic_library("nvrtc")
11860
{{if 'nvrtcGetErrorString' in found_functions}}
11961
try:
12062
global __nvrtcGetErrorString
@@ -299,6 +241,8 @@ cdef int cuPythonInit() except -1 nogil:
299241
{{endif}}
300242

301243
{{else}}
244+
with gil:
245+
handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc")
302246
{{if 'nvrtcGetErrorString' in found_functions}}
303247
global __nvrtcGetErrorString
304248
__nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')

Diff for: cuda_bindings/cuda/bindings/_path_finder/load_nvidia_dynamic_library.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ def _windows_load_with_dll_basename(name: str) -> int:
4545

4646
if name == "nvJitLink":
4747
dll_name = "nvJitLink_120_0.dll"
48+
elif name == "nvrtc":
49+
dll_name = "nvrtc64_120_0.dll"
4850
elif name == "nvvm":
4951
dll_name = "nvvm64_40_0.dll"
5052

0 commit comments

Comments
 (0)