Skip to content

Locate nvvm, libdevice, nvrtc, and cudart from nvidia-*-cu12 wheels #155

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
b8238f9
initial
brandon-b-miller Mar 10, 2025
9c56e55
slightly refactor cuda_paths
brandon-b-miller Mar 11, 2025
886b9b0
refactor libdevice search mechanism
brandon-b-miller Mar 11, 2025
7ffef77
debug get_cuda_paths
brandon-b-miller Mar 11, 2025
fcedb13
can launch kernel
brandon-b-miller Mar 11, 2025
151e565
cleanup
brandon-b-miller Mar 11, 2025
b4ededf
style
brandon-b-miller Mar 11, 2025
4f2bc2b
reset files
brandon-b-miller Mar 11, 2025
5f4ed8f
initial ci scripts
brandon-b-miller Mar 18, 2025
d4bf113
add pynvjitlink to tests and enable
brandon-b-miller Mar 18, 2025
443e998
locate nvrtc
brandon-b-miller Mar 22, 2025
1b436c6
working inside container
brandon-b-miller Mar 22, 2025
532d864
somewhat roundabout logic works for system/wheel
brandon-b-miller Mar 23, 2025
59bb493
skip tests with no set bin dir
brandon-b-miller Mar 24, 2025
f5dbee6
ensure builtins on windows
brandon-b-miller Mar 24, 2025
43c3ec2
merge/resolve
brandon-b-miller Mar 31, 2025
6a68eb8
remove system nvrtc from wheel test job
brandon-b-miller Mar 31, 2025
3dbd42d
refactor _get_nvvm_wheel
brandon-b-miller Mar 31, 2025
3f4ca51
actually install nvrtc from wheel
brandon-b-miller Apr 1, 2025
3e8651c
merge/resolve
brandon-b-miller Apr 2, 2025
2364278
merge/resolve
brandon-b-miller Apr 8, 2025
42af9ad
ruff
brandon-b-miller Apr 8, 2025
8cf66d0
prioritize wheels over system installs
brandon-b-miller Apr 10, 2025
4f686e9
global search priority
brandon-b-miller Apr 11, 2025
1e3f9a6
Merge branch 'main' into locate-nvvm-nvrtc-wheels
brandon-b-miller Apr 14, 2025
3d65c3d
local import of driver to determine cuda version
brandon-b-miller Apr 14, 2025
78cebea
short circuit
brandon-b-miller Apr 14, 2025
4e8f73f
bugfix
brandon-b-miller Apr 14, 2025
00bb4da
address reviews
brandon-b-miller Apr 15, 2025
5cebb44
Update numba_cuda/numba/cuda/cuda_paths.py
brandon-b-miller Apr 17, 2025
c14644a
get runtime lib from wheel as well
brandon-b-miller Apr 17, 2025
9b3bad9
remove system packages from conda test jobs
brandon-b-miller Apr 17, 2025
51f7694
only remove system packages in cuda 12 ci test jobs
brandon-b-miller Apr 17, 2025
8cc37d7
address reviews in cuda_paths.py
brandon-b-miller Apr 17, 2025
d5b68a9
add Graham's patch
brandon-b-miller Apr 17, 2025
dfe25c8
source cudart from wheel
brandon-b-miller Apr 17, 2025
9fc2531
simplify logic
brandon-b-miller Apr 18, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
- build-wheels
- test-wheels
- test-wheels-pynvjitlink
- test-wheels-deps-wheels
- build-docs
secrets: inherit
uses: rapidsai/shared-workflows/.github/workflows/[email protected]
Expand Down Expand Up @@ -94,6 +95,14 @@ jobs:
script: "ci/test_wheel_pynvjitlink.sh"
# This selects "ARCH=amd64 and CUDA >=12, with the latest supported Python for each CUDA major version".
matrix_filter: map(select(.ARCH == "amd64" and (.CUDA_VER | split(".") | .[0] | tonumber >= 12))) | group_by(.CUDA_VER|split(".")|map(tonumber)|.[0]) | map(max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))]))
test-wheels-deps-wheels:
needs:
- build-wheels
uses: ./.github/workflows/wheels-test.yaml
with:
build_type: pull-request
script: "ci/test_wheel_deps_wheels.sh"
matrix_filter: map(select(.ARCH == "amd64" and (.CUDA_VER | split(".") | .[0] | tonumber >= 12))) | group_by(.CUDA_VER|split(".")|map(tonumber)|.[0]) | map(max_by([(.PY_VER|split(".")|map(tonumber)), (.CUDA_VER|split(".")|map(tonumber))]))
build-docs:
needs:
- build-conda
Expand Down
55 changes: 55 additions & 0 deletions ci/test_wheel_deps_wheels.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#!/bin/bash
# Copyright (c) 2023-2024, NVIDIA CORPORATION

set -euo pipefail

rapids-logger "Install testing dependencies"
# TODO: Replace with rapids-dependency-file-generator
python -m pip install \
psutil \
cffi \
cuda-python \
nvidia-curand-cu12 \
nvidia-cuda-nvcc-cu12 \
nvidia-cuda-nvrtc-cu12 \
pynvjitlink-cu12 \
pytest


rapids-logger "Build tests"
PY_SCRIPT="
import numba_cuda
root = numba_cuda.__file__.rstrip('__init__.py')
test_dir = root + \"numba/cuda/tests/test_binary_generation/\"
print(test_dir)
"

NUMBA_CUDA_TEST_BIN_DIR=$(python -c "$PY_SCRIPT")
pushd $NUMBA_CUDA_TEST_BIN_DIR
make
popd

rapids-logger "Install wheel"
package=$(realpath wheel/numba_cuda*.whl)
echo "Package path: $package"
python -m pip install $package

rapids-logger "Check GPU usage"
nvidia-smi

RAPIDS_TESTS_DIR=${RAPIDS_TESTS_DIR:-"${PWD}/test-results"}/
mkdir -p "${RAPIDS_TESTS_DIR}"
pushd "${RAPIDS_TESTS_DIR}"

rapids-logger "Show Numba system info"
python -m numba --sysinfo

# remove cuda-nvvm-12-5 leaving libnvvm.so from nvidia-cuda-nvcc-cu12 only
apt-get update
apt remove --purge `dpkg --get-selections | grep cuda-nvvm | awk '{print $1}'` -y
apt remove --purge `dpkg --get-selections | grep cuda-nvrtc | awk '{print $1}'` -y

rapids-logger "Run Tests"
NUMBA_CUDA_ENABLE_PYNVJITLINK=1 NUMBA_CUDA_TEST_BIN_DIR=$NUMBA_CUDA_TEST_BIN_DIR python -m numba.runtests numba.cuda.tests -v

popd
176 changes: 164 additions & 12 deletions numba_cuda/numba/cuda/cuda_paths.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The following changes lead to success on Windows:

diff --git a/numba_cuda/numba/cuda/cuda_paths.py b/numba_cuda/numba/cuda/cuda_paths.py
index 9b38a86..bc3822a 100644
--- a/numba_cuda/numba/cuda/cuda_paths.py
+++ b/numba_cuda/numba/cuda/cuda_paths.py
@@ -224,8 +224,9 @@ def _cuda_home_static_cudalib_path():
 def _get_cudalib_wheel():
     """Get the cudalib path from the NVCC wheel."""
     site_paths = [site.getusersitepackages()] + site.getsitepackages()
+    libdir = IS_LINUX and "lib" or "bin"
     for sp in filter(None, site_paths):
-        cudalib_path = Path(sp, "nvidia", "cuda_runtime", "lib")
+        cudalib_path = Path(sp, "nvidia", "cuda_runtime", libdir)
         if cudalib_path.exists():
             return str(cudalib_path)
     return None
@@ -373,8 +374,20 @@ def get_cuda_home(*subdirs):

 def _get_nvvm_path():
     by, path = _get_nvvm_path_decision()
+
     if by == "NVIDIA NVCC Wheel":
-        path = os.path.join(path, "libnvvm.so")
+        platform_map = {
+            "linux": "libnvvm.so",
+            "win32": "nvvm64_40_0.dll",
+        }
+
+        for plat, dso_name in platform_map.items():
+            if sys.platform.startswith(plat):
+                break
+        else:
+            raise NotImplementedError("Unsupported platform")
+
+        path = os.path.join(path, dso_name)
     else:
         candidates = find_lib("nvvm", path)
         path = max(candidates) if candidates else None

Library test output:

(test-cuda-wheels) PS C:\Users\gmarkall\numbadev\numba-cuda> python -c "from numba import cuda; cuda.cudadrv.libs.test()"
Finding driver from candidates:
        nvcuda.dll
        \windows\system32\nvcuda.dll
Using loader <class 'ctypes.WinDLL'>
        Trying to load driver...        ok
                Loaded from nvcuda.dll
Finding nvvm from NVIDIA NVCC Wheel
        Located at D:\miniforge\envs\test-cuda-wheels\Lib\site-packages\nvidia\cuda_nvcc\nvvm\bin\nvvm64_40_0.dll
        Trying to open library...       ok
Finding nvrtc from NVIDIA NVCC Wheel
        Located at D:\miniforge\envs\test-cuda-wheels\Lib\site-packages\nvidia\cuda_nvrtc\bin\nvrtc64_120_0.dll
        Trying to open library...       ok
Finding cudart from NVIDIA NVCC Wheel
        Located at D:\miniforge\envs\test-cuda-wheels\Lib\site-packages\nvidia\cuda_runtime\bin\cudart64_12.dll
        Trying to open library...       ok
Finding cudadevrt from <unknown>
        Located at cudadevrt.lib
        Checking library...     ERROR: failed to find cudadevrt:
cudadevrt.lib not found
Finding libdevice from NVIDIA NVCC Wheel
        Located at D:\miniforge\envs\test-cuda-wheels\Lib\site-packages\nvidia\cuda_nvcc\nvvm\libdevice\libdevice.10.bc
        Checking library...     ok
Include directory configuration variable:
        CUDA_INCLUDE_PATH=cuda_include_not_found
Finding include directory from CUDA_INCLUDE_PATH Config Entry
        Located at cuda_include_not_found
        Checking include directory...   ERROR: failed to find cuda include directory:

We will just have to ignore that the includes and cudadevrt don't seem to be available in wheels, though.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

patch added in d5b68a9

Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import os
from collections import namedtuple
import platform

import site
from pathlib import Path
from numba.core.config import IS_WIN32
from numba.misc.findlib import find_lib, find_file
from numba.misc.findlib import find_lib
from numba import config
import glob
import ctypes


_env_path_tuple = namedtuple("_env_path_tuple", ["by", "info"])
Expand All @@ -29,6 +32,7 @@ def _get_libdevice_path_decision():
("Conda environment", get_conda_ctk()),
("Conda environment (NVIDIA package)", get_nvidia_libdevice_ctk()),
("CUDA_HOME", get_cuda_home("nvvm", "libdevice")),
("NVIDIA NVCC Wheel", get_libdevice_wheel()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we have wheels ahead of system packages? Like with the Debian package, if the wheel comes after the system package, you can never get the toolkit you want if it's different from the system one by installing a wheel.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really good point, will update this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be ahead of CUDA_HOME, too, at the same level as a conda package.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@bdice @gmarkall should the search priority be global across all the libraries? If so, 4f686e9 might be a good way forward.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it should and we should probably align the search priority with the work @rwgk is doing in cuda-python to build a universal lib finder / path handler.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's the NVIDIA/cuda-python#451 epic.

FYI I took the code from this PR a few weeks ago (NVIDIA/cuda-python#447). Random status update: I added some code on top, that passes all testing for the libraries with cuda-bindings (nvJitLink, nvrtc, nvvm). I'm in the middle of locally testing loading "all" other Nvidia libraries.

("System", get_system_ctk("nvvm", "libdevice")),
("Debian package", get_debian_pkg_libdevice()),
]
Expand All @@ -48,19 +52,141 @@ def _get_nvvm_path_decision():
("Conda environment", get_conda_ctk()),
("Conda environment (NVIDIA package)", get_nvidia_nvvm_ctk()),
("CUDA_HOME", get_cuda_home(*_nvvm_lib_dir())),
("NVIDIA NVCC Wheel", _get_nvvm_wheel()),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above on positioning this before the system toolkit install.

("System", get_system_ctk(*_nvvm_lib_dir())),
]

by, path = _find_valid_path(options)
return by, path


def _get_nvrtc_system_ctk():
sys_path = get_system_ctk("bin" if IS_WIN32 else "lib64")
candidates = find_lib("nvrtc", sys_path)
if candidates:
return max(candidates)


def _get_nvrtc_path_decision():
options = [
("CUDA_HOME", get_cuda_home("nvrtc")),
("Conda environment", get_conda_ctk()),
("NVIDIA NVCC Wheel", _get_nvrtc_wheel()),
("System", _get_nvrtc_system_ctk()),
]
by, path = _find_valid_path(options)
return by, path


def _get_nvvm_wheel():
platform_map = {
"linux": ("lib64", "libnvvm.so"),
"win32": ("bin", "nvvm64_40_0.dll"),
}

for plat, (dso_dir, dso_path) in platform_map.items():
if sys.platform.startswith(plat):
break
else:
raise NotImplementedError("Unsupported platform")

site_paths = [site.getusersitepackages()] + site.getsitepackages()

for sp in filter(None, site_paths):
nvvm_path = Path(sp, "nvidia", "cuda_nvcc", "nvvm", dso_dir, dso_path)
if nvvm_path.exists():
return str(nvvm_path.parent)

return None


def detect_nvrtc_major_cuda_version(lib_dir):
# TODO - is this a bad idea?
if sys.platform.startswith("linux"):
pattern = os.path.join(lib_dir, "libnvrtc.so.*")
elif sys.platform.startswith("win32"):
pattern = os.path.join(lib_dir, "nvrtc64_*.dll")
else:
raise NotImplementedError("Unsupported platform")

candidates = glob.glob(pattern)
for lib in candidates:
match = re.search(
r"libnvrtc\.so\.(\d+)(?:\.(\d+))?$"
if sys.platform.startswith("linux")
else r"nvrtc64_(\d+)(\d)_0",
os.path.basename(lib),
)
if match:
major, _ = match.groups()
return int(major)

raise RuntimeError("CUDA version could not be detected")


def get_nvrtc_dso_path():
site_paths = [site.getusersitepackages()] + site.getsitepackages()
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW, in my development branch I gravitated to using sys.path. All site packages wind up there, and whatever people want to put there:

https://github.com/NVIDIA/cuda-python/blob/7a0c06870b6260af92f90691f28279cbd40e43eb/cuda_bindings/cuda/bindings/_path_finder/sys_path_find_sub_dirs.py

for sp in site_paths:
lib_dir = os.path.join(
sp,
"nvidia",
"cuda_nvrtc",
("lib" if sys.platform.startswith("linux") else "bin")
if sp
else None,
)
if lib_dir and os.path.exists(lib_dir):
try:
major = detect_nvrtc_major_cuda_version(lib_dir)
if major == 11:
cu_ver = (
"11.2" if sys.platform.startswith("linux") else "112"
)
elif major == 12:
cu_ver = "12" if sys.platform.startswith("linux") else "120"
else:
raise NotImplementedError(f"CUDA {major} is not supported")

return os.path.join(
lib_dir,
f"libnvrtc.so.{cu_ver}"
if sys.platform.startswith("linux")
else f"nvrtc64_{cu_ver}_0.dll",
)
except RuntimeError:
continue


def _get_nvrtc_wheel():
dso_path = get_nvrtc_dso_path()
if dso_path:
try:
result = ctypes.CDLL(dso_path, mode=ctypes.RTLD_GLOBAL)
except OSError:
pass
else:
if sys.platform.startswith("win32"):
import win32api
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think win32api is not part of the Python standard library, and this is the first use of it in numba-cuda. Are we guaranteed that it's available on systems in which we end up here? Do we need to add a new dependency to numba-cuda for it?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this comes from the pywin32 package?

Copy link
Collaborator Author

@brandon-b-miller brandon-b-miller Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since we're not testing windows in CI I think it's debateable if we can actually inherit this piece of the patch as well as the other windows logic. What do you think? To answer your question, I'd imagine this likely errors on windows.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just noticed that pywin32 is a dependency of cuda-bindings (which is a dependency of cuda-python), so I think we can reasonably assume it to be around in this configuration for the lifetime of the cuda_paths.py file.


# This absolute path will
# always be correct regardless of the package source
nvrtc_path = win32api.GetModuleFileNameW(result._handle)
dso_dir = os.path.dirname(nvrtc_path)
builtins_path = os.path.join(
dso_dir,
[
f
for f in os.listdir(dso_dir)
if re.match("^nvrtc-builtins.*.dll$", f)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you consider using glob.glob(os.path.join(dso_dir, "nvrtc-builtins*.dll")) here?

I'd assert that there is exactly one match to be sure. (On Linux there is libnvrtc-builtins.alt.so, maybe one day it'll pop up on Windows as well?)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I hadn't yet, this particular piece of code was written to stay faithful to the nvmath-python patch that this PR is somewhat inheriting from, in particular this logic. I've been scratching my head a bit trying to figure out if we actually need this logic at all. I think @leofang possibly implemented the nvmath-python patch, maybe he can provide some context for the edge case this is covering? We can decide what to do from there.

Copy link

@rwgk rwgk Apr 15, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been scratching my head a bit trying to figure out if we actually need this logic at all.

I had a lot of fun working on understanding that.

For example:

Archive:  ./all_windows-x86_64_zip/cuda_nvrtc-windows-x86_64-12.8.93-archive.zip
  Length      Date    Time    Name
---------  ---------- -----   ----
...
        0  2025-02-21 21:41   cuda_nvrtc-windows-x86_64-12.8.93-archive/bin/
  6356480  2025-02-21 21:13   cuda_nvrtc-windows-x86_64-12.8.93-archive/bin/nvrtc-builtins64_128.dll
 86794240  2025-02-21 21:13   cuda_nvrtc-windows-x86_64-12.8.93-archive/bin/nvrtc64_120_0.alt.dll
 86728192  2025-02-21 21:13   cuda_nvrtc-windows-x86_64-12.8.93-archive/bin/nvrtc64_120_0.dll
...

nvrtc64_120_0.dll depends on nvrtc-builtins64_128.dll. But unlike under Linux, under Windows the dynamic loader does not automagically look in the same directory for dependent DLLs. Some extra work is needed to ensure that the dependency is found when loading the main DLL.

(I didn't try to fully understand how that's handled in this PR. I'm still working on the exact details of that under my cuda-python PRs.)

(It's even more interesting under Linux, because there is a .alt file also for the builtins .so.)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR just asserts that it exists rather than persisting it to a python global that AFAICT isn't used anywhere. That said, numba-cuda isn't running windows tests in CI, and there's no job that's testing this situation on windows.

I had left it in anyways with the goal of nvmath-python being able to delete their patch instead of having to do more additional patching for their supported platforms, but without testing it in our CI, its tough to know if it's working or not...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The original patch was meant to pre-load all nvrtc64_120_0.dll's dependencies. As long as all needed DLLs are loaded, at JIT time Windows is happy. (Very different behavior from Linux, I know.)

][0],
)
assert os.path.exists(builtins_path)
return Path(dso_path)


def _get_libdevice_paths():
by, libdir = _get_libdevice_path_decision()
# Search for pattern
pat = r"libdevice(\.\d+)*\.bc$"
candidates = find_file(re.compile(pat), libdir)
# Keep only the max (most recent version) of the bitcode files.
out = max(candidates, default=None)
out = os.path.join(libdir, "libdevice.10.bc")
return _env_path_tuple(by, out)


Expand Down Expand Up @@ -116,9 +242,9 @@ def get_system_ctk(*subdirs):
if sys.platform.startswith("linux"):
# Is cuda alias to /usr/local/cuda?
# We are intentionally not getting versioned cuda installation.
base = "/usr/local/cuda"
if os.path.exists(base):
return os.path.join(base, *subdirs)
result = os.path.join("/usr/local/cuda", *subdirs)
if os.path.exists(result):
return result


def get_conda_ctk():
Expand Down Expand Up @@ -211,8 +337,23 @@ def get_cuda_home(*subdirs):

def _get_nvvm_path():
by, path = _get_nvvm_path_decision()
candidates = find_lib("nvvm", path)
path = max(candidates) if candidates else None
if by == "NVIDIA NVCC Wheel":
path = os.path.join(path, "libnvvm.so")
else:
candidates = find_lib("nvvm", path)
path = max(candidates) if candidates else None
return _env_path_tuple(by, path)


def _get_nvrtc_path():
by, path = _get_nvrtc_path_decision()
if by == "NVIDIA NVCC Wheel":
path = str(path)
elif by == "System":
return _env_path_tuple(by, path)
else:
candidates = find_lib("nvrtc", path)
path = max(candidates) if candidates else None
return _env_path_tuple(by, path)


Expand All @@ -234,6 +375,7 @@ def get_cuda_paths():
# Not in cache
d = {
"nvvm": _get_nvvm_path(),
"nvrtc": _get_nvrtc_path(),
"libdevice": _get_libdevice_paths(),
"cudalib_dir": _get_cudalib_dir(),
"static_cudalib_dir": _get_static_cudalib_dir(),
Expand All @@ -255,6 +397,16 @@ def get_debian_pkg_libdevice():
return pkg_libdevice_location


def get_libdevice_wheel():
nvvm_path = _get_nvvm_wheel()
if nvvm_path is None:
return None
nvvm_path = Path(nvvm_path)
libdevice_path = nvvm_path.parent / "libdevice"

return str(libdevice_path)


def get_current_cuda_target_name():
"""Determine conda's CTK target folder based on system and machine arch.

Expand Down
6 changes: 4 additions & 2 deletions numba_cuda/numba/cuda/cudadrv/libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def get_cudalib(lib, static=False):
'libnvvm.so' for 'nvvm') so that we may attempt to load it using the system
loader's search mechanism.
"""
if lib == "nvvm":
return get_cuda_paths()["nvvm"].info or _dllnamepattern % "nvvm"
if lib in {"nvrtc", "nvvm"}:
return get_cuda_paths()[lib].info or _dllnamepattern % lib
else:
dir_type = "static_cudalib_dir" if static else "cudalib_dir"
libdir = get_cuda_paths()[dir_type].info
Expand Down Expand Up @@ -92,6 +92,8 @@ def check_static_lib(path):
def _get_source_variable(lib, static=False):
if lib == "nvvm":
return get_cuda_paths()["nvvm"].by
elif lib == "nvrtc":
return get_cuda_paths()["nvrtc"].by
elif lib == "libdevice":
return get_cuda_paths()["libdevice"].by
elif lib == "include_dir":
Expand Down
Loading