Skip to content

Commit 0c5aca5

Browse files
committed
Minimal changes to adapt numba-cuda/numba_cuda/numba/cuda/cuda_paths.py from NVIDIA/numba-cuda#155
1 parent ed0ebb3 commit 0c5aca5

File tree

2 files changed

+46
-5
lines changed

2 files changed

+46
-5
lines changed

cuda_bindings/cuda/bindings/ecosystem/cuda_paths.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,52 @@
33
import re
44
import site
55
import sys
6+
import traceback
7+
import warnings
68
from collections import namedtuple
79
from pathlib import Path
810

9-
from numba import config
10-
from numba.core.config import IS_WIN32
11-
from numba.misc.findlib import find_file, find_lib
11+
from .findlib import find_file, find_lib
12+
13+
IS_WIN32 = sys.platform.startswith("win32")
1214

1315
_env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
1416

1517

18+
def _get_numba_CUDA_INCLUDE_PATH():
19+
# From numba/numba/core/config.py
20+
21+
def _readenv(name, ctor, default):
22+
value = os.environ.get(name)
23+
if value is None:
24+
return default() if callable(default) else default
25+
try:
26+
return ctor(value)
27+
except Exception:
28+
warnings.warn( # noqa: B028
29+
f"Environment variable '{name}' is defined but "
30+
f"its associated value '{value}' could not be "
31+
"parsed.\nThe parse failed with exception:\n"
32+
f"{traceback.format_exc()}",
33+
RuntimeWarning,
34+
)
35+
return default
36+
37+
if IS_WIN32:
38+
cuda_path = os.environ.get("CUDA_PATH")
39+
if cuda_path: # noqa: SIM108
40+
default_cuda_include_path = os.path.join(cuda_path, "include")
41+
else:
42+
default_cuda_include_path = "cuda_include_not_found"
43+
else:
44+
default_cuda_include_path = os.path.join(os.sep, "usr", "local", "cuda", "include")
45+
CUDA_INCLUDE_PATH = _readenv("NUMBA_CUDA_INCLUDE_PATH", str, default_cuda_include_path)
46+
return CUDA_INCLUDE_PATH
47+
48+
49+
config_CUDA_INCLUDE_PATH = _get_numba_CUDA_INCLUDE_PATH()
50+
51+
1652
def _find_valid_path(options):
1753
"""Find valid path from *options*, which is a list of 2-tuple of
1854
(name, path). Return first pair where *path* is not None.
@@ -221,7 +257,7 @@ def get_nvidia_static_cudalib_ctk():
221257
if not nvvm_ctk:
222258
return
223259

224-
if IS_WIN32 and ("Library" not in nvvm_ctk):
260+
if IS_WIN32 and ("Library" not in nvvm_ctk): # noqa: SIM108
225261
# Location specific to CUDA 11.x packages on Windows
226262
dirs = ("Lib", "x64")
227263
else:
@@ -360,7 +396,7 @@ def _get_include_dir():
360396
"""Find the root include directory."""
361397
options = [
362398
("Conda environment (NVIDIA package)", get_conda_include_dir()),
363-
("CUDA_INCLUDE_PATH Config Entry", config.CUDA_INCLUDE_PATH),
399+
("CUDA_INCLUDE_PATH Config Entry", config_CUDA_INCLUDE_PATH),
364400
# TODO: add others
365401
]
366402
by, include_dir = _find_valid_path(options)
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from cuda.bindings.ecosystem import cuda_paths
2+
3+
paths = cuda_paths.get_cuda_paths()
4+
for k, v in cuda_paths.get_cuda_paths().items():
5+
print(f"{k}: {v}", flush=True)

0 commit comments

Comments
 (0)