|
9 | 9 | # This code was automatically generated with version 12.8.0. Do not modify it directly.
|
10 | 10 | {{if 'Windows' == platform.system()}}
|
11 | 11 | import os
|
12 |
| -import site |
13 |
| -import struct |
14 | 12 | import win32api
|
15 |
| -from pywintypes import error |
16 | 13 | {{else}}
|
17 | 14 | cimport cuda.bindings._lib.dlfcn as dlfcn
|
| 15 | +from libc.stdint cimport uintptr_t |
18 | 16 | {{endif}}
|
| 17 | +from cuda.bindings import path_finder |
19 | 18 |
|
20 | 19 | cdef bint __cuPythonInit = False
|
21 | 20 | {{if 'nvrtcGetErrorString' in found_functions}}cdef void *__nvrtcGetErrorString = NULL{{endif}}
|
@@ -47,74 +46,17 @@ cdef bint __cuPythonInit = False
|
47 | 46 |
|
48 | 47 | cdef int cuPythonInit() except -1 nogil:
|
49 | 48 | {{if 'Windows' != platform.system()}}
|
50 |
| - cdef char* err_msg |
| 49 | + cdef void* handle = NULL |
51 | 50 | {{endif}}
|
52 | 51 |
|
53 | 52 | global __cuPythonInit
|
54 | 53 | if __cuPythonInit:
|
55 | 54 | return 0
|
56 | 55 | __cuPythonInit = True
|
57 | 56 |
|
58 |
| - # Load library |
59 |
| - {{if 'Windows' == platform.system()}} |
60 |
| - with gil: |
61 |
| - # First check if the DLL has been loaded by 3rd parties |
62 |
| - try: |
63 |
| - handle = win32api.GetModuleHandle("nvrtc64_120_0.dll") |
64 |
| - except: |
65 |
| - handle = None |
66 |
| - |
67 |
| - # Else try default search |
68 |
| - if not handle: |
69 |
| - LOAD_LIBRARY_SAFE_CURRENT_DIRS = 0x00002000 |
70 |
| - try: |
71 |
| - handle = win32api.LoadLibraryEx("nvrtc64_120_0.dll", 0, LOAD_LIBRARY_SAFE_CURRENT_DIRS) |
72 |
| - except: |
73 |
| - pass |
74 |
| - |
75 |
| - # Final check if DLLs can be found within pip installations |
76 |
| - if not handle: |
77 |
| - site_packages = [site.getusersitepackages()] + site.getsitepackages() |
78 |
| - for sp in site_packages: |
79 |
| - mod_path = os.path.join(sp, "nvidia", "cuda_nvrtc", "bin") |
80 |
| - if not os.path.isdir(mod_path): |
81 |
| - continue |
82 |
| - os.add_dll_directory(mod_path) |
83 |
| - LOAD_LIBRARY_SEARCH_DEFAULT_DIRS = 0x00001000 |
84 |
| - LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR = 0x00000100 |
85 |
| - try: |
86 |
| - handle = win32api.LoadLibraryEx( |
87 |
| - # Note: LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR needs an abs path... |
88 |
| - os.path.join(mod_path, "nvrtc64_120_0.dll"), |
89 |
| - 0, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR) |
90 |
| - |
91 |
| - # Note: nvrtc64_120_0.dll calls into nvrtc-builtins64_*.dll which is |
92 |
| - # located in the same mod_path. |
93 |
| - # Update PATH environ so that the two dlls can find each other |
94 |
| - os.environ["PATH"] = os.pathsep.join((os.environ.get("PATH", ""), mod_path)) |
95 |
| - except: |
96 |
| - pass |
97 |
| - |
98 |
| - if not handle: |
99 |
| - raise RuntimeError('Failed to LoadLibraryEx nvrtc64_120_0.dll') |
100 |
| - {{else}} |
101 |
| - with gil: |
102 |
| - print("\nLOOOK dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW)", flush=True) |
103 |
| - handle = dlfcn.dlopen('libnvrtc.so.12', dlfcn.RTLD_NOW) |
104 |
| - if handle == NULL: |
105 |
| - with gil: |
106 |
| - err_msg = dlfcn.dlerror() |
107 |
| - if err_msg == NULL: |
108 |
| - err_msg_str = 'Unknown error' |
109 |
| - else: |
110 |
| - err_msg_str = err_msg.decode('utf-8', errors='backslashreplace') |
111 |
| - raise RuntimeError(f'Failed to dlopen libnvrtc.so.12: {err_msg_str}') |
112 |
| - {{endif}} |
113 |
| - |
114 |
| - |
115 |
| - # Load function |
116 | 57 | {{if 'Windows' == platform.system()}}
|
117 | 58 | with gil:
|
| 59 | + handle = path_finder.load_nvidia_dynamic_library("nvrtc") |
118 | 60 | {{if 'nvrtcGetErrorString' in found_functions}}
|
119 | 61 | try:
|
120 | 62 | global __nvrtcGetErrorString
|
@@ -299,6 +241,8 @@ cdef int cuPythonInit() except -1 nogil:
|
299 | 241 | {{endif}}
|
300 | 242 |
|
301 | 243 | {{else}}
|
| 244 | + with gil: |
| 245 | + handle = <void*><uintptr_t>path_finder.load_nvidia_dynamic_library("nvrtc") |
302 | 246 | {{if 'nvrtcGetErrorString' in found_functions}}
|
303 | 247 | global __nvrtcGetErrorString
|
304 | 248 | __nvrtcGetErrorString = dlfcn.dlsym(handle, 'nvrtcGetErrorString')
|
|
0 commit comments