Skip to content

Commit ed0ebb3

Browse files
committed
ruff format NO MANUAL CHANGES
1 parent d31920c commit ed0ebb3

File tree

1 file changed

+78
-97
lines changed

1 file changed

+78
-97
lines changed

cuda_bindings/cuda/bindings/ecosystem/cuda_paths.py

Lines changed: 78 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
1-
import sys
2-
import re
31
import os
4-
from collections import namedtuple
52
import platform
3+
import re
64
import site
5+
import sys
6+
from collections import namedtuple
77
from pathlib import Path
8-
from numba.core.config import IS_WIN32
9-
from numba.misc.findlib import find_lib, find_file
10-
from numba import config
118

9+
from numba import config
10+
from numba.core.config import IS_WIN32
11+
from numba.misc.findlib import find_file, find_lib
1212

13-
_env_path_tuple = namedtuple('_env_path_tuple', ['by', 'info'])
13+
_env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
1414

1515

1616
def _find_valid_path(options):
@@ -22,52 +22,50 @@ def _find_valid_path(options):
2222
if data is not None:
2323
return by, data
2424
else:
25-
return '<unknown>', None
25+
return "<unknown>", None
2626

2727

2828
def _get_libdevice_path_decision():
2929
options = [
30-
('Conda environment', get_conda_ctk()),
31-
('Conda environment (NVIDIA package)', get_nvidia_libdevice_ctk()),
32-
('CUDA_HOME', get_cuda_home('nvvm', 'libdevice')),
33-
('Debian package', get_debian_pkg_libdevice()),
34-
('NVIDIA NVCC Wheel', get_libdevice_wheel()),
30+
("Conda environment", get_conda_ctk()),
31+
("Conda environment (NVIDIA package)", get_nvidia_libdevice_ctk()),
32+
("CUDA_HOME", get_cuda_home("nvvm", "libdevice")),
33+
("Debian package", get_debian_pkg_libdevice()),
34+
("NVIDIA NVCC Wheel", get_libdevice_wheel()),
3535
]
36-
libdevice_ctk_dir = get_system_ctk('nvvm', 'libdevice')
36+
libdevice_ctk_dir = get_system_ctk("nvvm", "libdevice")
3737
if os.path.exists(libdevice_ctk_dir):
38-
options.append(('System', libdevice_ctk_dir))
38+
options.append(("System", libdevice_ctk_dir))
3939

4040
by, libdir = _find_valid_path(options)
4141
return by, libdir
4242

4343

4444
def _nvvm_lib_dir():
4545
if IS_WIN32:
46-
return 'nvvm', 'bin'
46+
return "nvvm", "bin"
4747
else:
48-
return 'nvvm', 'lib64'
48+
return "nvvm", "lib64"
4949

5050

5151
def _get_nvvm_path_decision():
5252
options = [
53-
('Conda environment', get_conda_ctk()),
54-
('Conda environment (NVIDIA package)', get_nvidia_nvvm_ctk()),
55-
('CUDA_HOME', get_cuda_home(*_nvvm_lib_dir())),
56-
('NVIDIA NVCC Wheel', _get_nvvm_wheel()),
53+
("Conda environment", get_conda_ctk()),
54+
("Conda environment (NVIDIA package)", get_nvidia_nvvm_ctk()),
55+
("CUDA_HOME", get_cuda_home(*_nvvm_lib_dir())),
56+
("NVIDIA NVCC Wheel", _get_nvvm_wheel()),
5757
]
5858
# need to ensure nvvm dir actually exists
5959
nvvm_ctk_dir = get_system_ctk(*_nvvm_lib_dir())
6060
if os.path.exists(nvvm_ctk_dir):
61-
options.append(('System', nvvm_ctk_dir))
61+
options.append(("System", nvvm_ctk_dir))
6262

6363
by, path = _find_valid_path(options)
6464
return by, path
6565

6666

6767
def _get_nvvm_wheel():
68-
site_paths = [
69-
site.getusersitepackages()
70-
] + site.getsitepackages() + ["conda", None]
68+
site_paths = [site.getusersitepackages()] + site.getsitepackages() + ["conda", None]
7169
for sp in site_paths:
7270
# The SONAME is taken based on public CTK 12.x releases
7371
if sys.platform.startswith("linux"):
@@ -82,13 +80,7 @@ def _get_nvvm_wheel():
8280
raise AssertionError()
8381

8482
if sp is not None:
85-
dso_dir = os.path.join(
86-
sp,
87-
"nvidia",
88-
"cuda_nvcc",
89-
"nvvm",
90-
dso_dir
91-
)
83+
dso_dir = os.path.join(sp, "nvidia", "cuda_nvcc", "nvvm", dso_dir)
9284
dso_path = os.path.join(dso_dir, dso_path)
9385
if os.path.exists(dso_path):
9486
return str(Path(dso_path).parent)
@@ -101,7 +93,7 @@ def _get_libdevice_paths():
10193
out = os.path.join(libdir, "libdevice.10.bc")
10294
else:
10395
# Search for pattern
104-
pat = r'libdevice(\.\d+)*\.bc$'
96+
pat = r"libdevice(\.\d+)*\.bc$"
10597
candidates = find_file(re.compile(pat), libdir)
10698
# Keep only the max (most recent version) of the bitcode files.
10799
out = max(candidates, default=None)
@@ -110,35 +102,35 @@ def _get_libdevice_paths():
110102

111103
def _cudalib_path():
112104
if IS_WIN32:
113-
return 'bin'
105+
return "bin"
114106
else:
115-
return 'lib64'
107+
return "lib64"
116108

117109

118110
def _cuda_home_static_cudalib_path():
119111
if IS_WIN32:
120-
return ('lib', 'x64')
112+
return ("lib", "x64")
121113
else:
122-
return ('lib64',)
114+
return ("lib64",)
123115

124116

125117
def _get_cudalib_dir_path_decision():
126118
options = [
127-
('Conda environment', get_conda_ctk()),
128-
('Conda environment (NVIDIA package)', get_nvidia_cudalib_ctk()),
129-
('CUDA_HOME', get_cuda_home(_cudalib_path())),
130-
('System', get_system_ctk(_cudalib_path())),
119+
("Conda environment", get_conda_ctk()),
120+
("Conda environment (NVIDIA package)", get_nvidia_cudalib_ctk()),
121+
("CUDA_HOME", get_cuda_home(_cudalib_path())),
122+
("System", get_system_ctk(_cudalib_path())),
131123
]
132124
by, libdir = _find_valid_path(options)
133125
return by, libdir
134126

135127

136128
def _get_static_cudalib_dir_path_decision():
137129
options = [
138-
('Conda environment', get_conda_ctk()),
139-
('Conda environment (NVIDIA package)', get_nvidia_static_cudalib_ctk()),
140-
('CUDA_HOME', get_cuda_home(*_cuda_home_static_cudalib_path())),
141-
('System', get_system_ctk(_cudalib_path())),
130+
("Conda environment", get_conda_ctk()),
131+
("Conda environment (NVIDIA package)", get_nvidia_static_cudalib_ctk()),
132+
("CUDA_HOME", get_cuda_home(*_cuda_home_static_cudalib_path())),
133+
("System", get_system_ctk(_cudalib_path())),
142134
]
143135
by, libdir = _find_valid_path(options)
144136
return by, libdir
@@ -155,92 +147,86 @@ def _get_static_cudalib_dir():
155147

156148

157149
def get_system_ctk(*subdirs):
158-
"""Return path to system-wide cudatoolkit; or, None if it doesn't exist.
159-
"""
150+
"""Return path to system-wide cudatoolkit; or, None if it doesn't exist."""
160151
# Linux?
161-
if sys.platform.startswith('linux'):
152+
if sys.platform.startswith("linux"):
162153
# Is cuda alias to /usr/local/cuda?
163154
# We are intentionally not getting versioned cuda installation.
164-
base = '/usr/local/cuda'
155+
base = "/usr/local/cuda"
165156
if os.path.exists(base):
166157
return os.path.join(base, *subdirs)
167158

168159

169160
def get_conda_ctk():
170-
"""Return path to directory containing the shared libraries of cudatoolkit.
171-
"""
172-
is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
161+
"""Return path to directory containing the shared libraries of cudatoolkit."""
162+
is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
173163
if not is_conda_env:
174164
return
175165
# Assume the existence of NVVM to imply cudatoolkit installed
176-
paths = find_lib('nvvm')
166+
paths = find_lib("nvvm")
177167
if not paths:
178168
return
179169
# Use the directory name of the max path
180170
return os.path.dirname(max(paths))
181171

182172

183173
def get_nvidia_nvvm_ctk():
184-
"""Return path to directory containing the NVVM shared library.
185-
"""
186-
is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
174+
"""Return path to directory containing the NVVM shared library."""
175+
is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
187176
if not is_conda_env:
188177
return
189178

190179
# Assume the existence of NVVM in the conda env implies that a CUDA toolkit
191180
# conda package is installed.
192181

193182
# First, try the location used on Linux and the Windows 11.x packages
194-
libdir = os.path.join(sys.prefix, 'nvvm', _cudalib_path())
183+
libdir = os.path.join(sys.prefix, "nvvm", _cudalib_path())
195184
if not os.path.exists(libdir) or not os.path.isdir(libdir):
196185
# If that fails, try the location used for Windows 12.x packages
197-
libdir = os.path.join(sys.prefix, 'Library', 'nvvm', _cudalib_path())
186+
libdir = os.path.join(sys.prefix, "Library", "nvvm", _cudalib_path())
198187
if not os.path.exists(libdir) or not os.path.isdir(libdir):
199188
# If that doesn't exist either, assume we don't have the NVIDIA
200189
# conda package
201190
return
202191

203-
paths = find_lib('nvvm', libdir=libdir)
192+
paths = find_lib("nvvm", libdir=libdir)
204193
if not paths:
205194
return
206195
# Use the directory name of the max path
207196
return os.path.dirname(max(paths))
208197

209198

210199
def get_nvidia_libdevice_ctk():
211-
"""Return path to directory containing the libdevice library.
212-
"""
200+
"""Return path to directory containing the libdevice library."""
213201
nvvm_ctk = get_nvidia_nvvm_ctk()
214202
if not nvvm_ctk:
215203
return
216204
nvvm_dir = os.path.dirname(nvvm_ctk)
217-
return os.path.join(nvvm_dir, 'libdevice')
205+
return os.path.join(nvvm_dir, "libdevice")
218206

219207

220208
def get_nvidia_cudalib_ctk():
221-
"""Return path to directory containing the shared libraries of cudatoolkit.
222-
"""
209+
"""Return path to directory containing the shared libraries of cudatoolkit."""
223210
nvvm_ctk = get_nvidia_nvvm_ctk()
224211
if not nvvm_ctk:
225212
return
226213
env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
227-
subdir = 'bin' if IS_WIN32 else 'lib'
214+
subdir = "bin" if IS_WIN32 else "lib"
228215
return os.path.join(env_dir, subdir)
229216

230217

231218
def get_nvidia_static_cudalib_ctk():
232-
"""Return path to directory containing the static libraries of cudatoolkit.
233-
"""
219+
"""Return path to directory containing the static libraries of cudatoolkit."""
234220
nvvm_ctk = get_nvidia_nvvm_ctk()
235221
if not nvvm_ctk:
236222
return
237223

238224
if IS_WIN32 and ("Library" not in nvvm_ctk):
239225
# Location specific to CUDA 11.x packages on Windows
240-
dirs = ('Lib', 'x64')
226+
dirs = ("Lib", "x64")
241227
else:
242228
# Linux, or Windows with CUDA 12.x packages
243-
dirs = ('lib',)
229+
dirs = ("lib",)
244230

245231
env_dir = os.path.dirname(os.path.dirname(nvvm_ctk))
246232
return os.path.join(env_dir, *dirs)
@@ -251,10 +237,10 @@ def get_cuda_home(*subdirs):
251237
If *subdirs* are the subdirectory name to be appended in the resulting
252238
path.
253239
"""
254-
cuda_home = os.environ.get('CUDA_HOME')
240+
cuda_home = os.environ.get("CUDA_HOME")
255241
if cuda_home is None:
256242
# Try Windows CUDA installation without Anaconda
257-
cuda_home = os.environ.get('CUDA_PATH')
243+
cuda_home = os.environ.get("CUDA_PATH")
258244
if cuda_home is not None:
259245
return os.path.join(cuda_home, *subdirs)
260246

@@ -265,7 +251,7 @@ def _get_nvvm_path():
265251
# The NVVM path is a directory, not a file
266252
path = os.path.join(path, "libnvvm.so")
267253
else:
268-
candidates = find_lib('nvvm', path)
254+
candidates = find_lib("nvvm", path)
269255
path = max(candidates) if candidates else None
270256
return _env_path_tuple(by, path)
271257

@@ -282,16 +268,16 @@ def get_cuda_paths():
282268
Note: The result of the function is cached.
283269
"""
284270
# Check cache
285-
if hasattr(get_cuda_paths, '_cached_result'):
271+
if hasattr(get_cuda_paths, "_cached_result"):
286272
return get_cuda_paths._cached_result
287273
else:
288274
# Not in cache
289275
d = {
290-
'nvvm': _get_nvvm_path(),
291-
'libdevice': _get_libdevice_paths(),
292-
'cudalib_dir': _get_cudalib_dir(),
293-
'static_cudalib_dir': _get_static_cudalib_dir(),
294-
'include_dir': _get_include_dir(),
276+
"nvvm": _get_nvvm_path(),
277+
"libdevice": _get_libdevice_paths(),
278+
"cudalib_dir": _get_cudalib_dir(),
279+
"static_cudalib_dir": _get_static_cudalib_dir(),
280+
"include_dir": _get_include_dir(),
295281
}
296282
# Cache result
297283
get_cuda_paths._cached_result = d
@@ -303,7 +289,7 @@ def get_debian_pkg_libdevice():
303289
Return the Debian NVIDIA Maintainers-packaged libdevice location, if it
304290
exists.
305291
"""
306-
pkg_libdevice_location = '/usr/lib/nvidia-cuda-toolkit/libdevice'
292+
pkg_libdevice_location = "/usr/lib/nvidia-cuda-toolkit/libdevice"
307293
if not os.path.exists(pkg_libdevice_location):
308294
return None
309295
return pkg_libdevice_location
@@ -332,13 +318,10 @@ def get_current_cuda_target_name():
332318
machine = platform.machine()
333319

334320
if system == "Linux":
335-
arch_to_targets = {
336-
'x86_64': 'x86_64-linux',
337-
'aarch64': 'sbsa-linux'
338-
}
321+
arch_to_targets = {"x86_64": "x86_64-linux", "aarch64": "sbsa-linux"}
339322
elif system == "Windows":
340323
arch_to_targets = {
341-
'AMD64': 'x64',
324+
"AMD64": "x64",
342325
}
343326
else:
344327
arch_to_targets = {}
@@ -351,35 +334,33 @@ def get_conda_include_dir():
351334
Return the include directory in the current conda environment, if one
352335
is active and it exists.
353336
"""
354-
is_conda_env = os.path.exists(os.path.join(sys.prefix, 'conda-meta'))
337+
is_conda_env = os.path.exists(os.path.join(sys.prefix, "conda-meta"))
355338
if not is_conda_env:
356339
return
357340

358341
if platform.system() == "Windows":
359-
include_dir = os.path.join(
360-
sys.prefix, 'Library', 'include'
361-
)
342+
include_dir = os.path.join(sys.prefix, "Library", "include")
362343
elif target_name := get_current_cuda_target_name():
363-
include_dir = os.path.join(
364-
sys.prefix, 'targets', target_name, 'include'
365-
)
344+
include_dir = os.path.join(sys.prefix, "targets", target_name, "include")
366345
else:
367346
# A fallback when target cannot determined
368347
# though usually it shouldn't.
369-
include_dir = os.path.join(sys.prefix, 'include')
348+
include_dir = os.path.join(sys.prefix, "include")
370349

371-
if (os.path.exists(include_dir) and os.path.isdir(include_dir)
372-
and os.path.exists(os.path.join(include_dir,
373-
'cuda_device_runtime_api.h'))):
350+
if (
351+
os.path.exists(include_dir)
352+
and os.path.isdir(include_dir)
353+
and os.path.exists(os.path.join(include_dir, "cuda_device_runtime_api.h"))
354+
):
374355
return include_dir
375356
return
376357

377358

378359
def _get_include_dir():
379360
"""Find the root include directory."""
380361
options = [
381-
('Conda environment (NVIDIA package)', get_conda_include_dir()),
382-
('CUDA_INCLUDE_PATH Config Entry', config.CUDA_INCLUDE_PATH),
362+
("Conda environment (NVIDIA package)", get_conda_include_dir()),
363+
("CUDA_INCLUDE_PATH Config Entry", config.CUDA_INCLUDE_PATH),
383364
# TODO: add others
384365
]
385366
by, include_dir = _find_valid_path(options)

0 commit comments

Comments
 (0)