|
3 | 3 | import re
|
4 | 4 | import site
|
5 | 5 | import sys
|
| 6 | +import traceback |
| 7 | +import warnings |
6 | 8 | from collections import namedtuple
|
7 | 9 | from pathlib import Path
|
8 | 10 |
|
9 |
| -from numba import config |
10 |
| -from numba.core.config import IS_WIN32 |
11 |
| -from numba.misc.findlib import find_file, find_lib |
| 11 | +from .findlib import find_file, find_lib |
| 12 | + |
| 13 | +IS_WIN32 = sys.platform.startswith("win32") |
12 | 14 |
|
13 | 15 | _env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
|
14 | 16 |
|
15 | 17 |
|
| 18 | +def _get_numba_CUDA_INCLUDE_PATH(): |
| 19 | + # From numba/numba/core/config.py |
| 20 | + |
| 21 | + def _readenv(name, ctor, default): |
| 22 | + value = os.environ.get(name) |
| 23 | + if value is None: |
| 24 | + return default() if callable(default) else default |
| 25 | + try: |
| 26 | + return ctor(value) |
| 27 | + except Exception: |
| 28 | + warnings.warn( # noqa: B028 |
| 29 | + f"Environment variable '{name}' is defined but " |
| 30 | + f"its associated value '{value}' could not be " |
| 31 | + "parsed.\nThe parse failed with exception:\n" |
| 32 | + f"{traceback.format_exc()}", |
| 33 | + RuntimeWarning, |
| 34 | + ) |
| 35 | + return default |
| 36 | + |
| 37 | + if IS_WIN32: |
| 38 | + cuda_path = os.environ.get("CUDA_PATH") |
| 39 | + if cuda_path: # noqa: SIM108 |
| 40 | + default_cuda_include_path = os.path.join(cuda_path, "include") |
| 41 | + else: |
| 42 | + default_cuda_include_path = "cuda_include_not_found" |
| 43 | + else: |
| 44 | + default_cuda_include_path = os.path.join(os.sep, "usr", "local", "cuda", "include") |
| 45 | + CUDA_INCLUDE_PATH = _readenv("NUMBA_CUDA_INCLUDE_PATH", str, default_cuda_include_path) |
| 46 | + return CUDA_INCLUDE_PATH |
| 47 | + |
| 48 | + |
| 49 | +config_CUDA_INCLUDE_PATH = _get_numba_CUDA_INCLUDE_PATH() |
| 50 | + |
| 51 | + |
16 | 52 | def _find_valid_path(options):
|
17 | 53 | """Find valid path from *options*, which is a list of 2-tuple of
|
18 | 54 | (name, path). Return first pair where *path* is not None.
|
@@ -221,7 +257,7 @@ def get_nvidia_static_cudalib_ctk():
|
221 | 257 | if not nvvm_ctk:
|
222 | 258 | return
|
223 | 259 |
|
224 |
| - if IS_WIN32 and ("Library" not in nvvm_ctk): |
| 260 | + if IS_WIN32 and ("Library" not in nvvm_ctk): # noqa: SIM108 |
225 | 261 | # Location specific to CUDA 11.x packages on Windows
|
226 | 262 | dirs = ("Lib", "x64")
|
227 | 263 | else:
|
@@ -360,7 +396,7 @@ def _get_include_dir():
|
360 | 396 | """Find the root include directory."""
|
361 | 397 | options = [
|
362 | 398 | ("Conda environment (NVIDIA package)", get_conda_include_dir()),
|
363 |
| - ("CUDA_INCLUDE_PATH Config Entry", config.CUDA_INCLUDE_PATH), |
| 399 | + ("CUDA_INCLUDE_PATH Config Entry", config_CUDA_INCLUDE_PATH), |
364 | 400 | # TODO: add others
|
365 | 401 | ]
|
366 | 402 | by, include_dir = _find_valid_path(options)
|
|
0 commit comments