|
3 | 3 | import functools
|
4 | 4 | import logging
|
5 | 5 | import os
|
| 6 | +import subprocess |
| 7 | +import sys |
6 | 8 | from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
|
7 | 9 |
|
8 | 10 | import numpy as np
|
|
12 | 14 | from torch.fx.node import Argument, Target
|
13 | 15 | from torch.fx.passes.shape_prop import TensorMetadata
|
14 | 16 | from torch_tensorrt import _enums
|
| 17 | +from torch_tensorrt._enums import Platform |
15 | 18 | from torch_tensorrt.dynamo._settings import CompilationSettings
|
16 | 19 | from torch_tensorrt.dynamo._SourceIR import SourceIR
|
17 | 20 | from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
|
def load_tensorrt_llm() -> bool:
    """Load and initialize the TensorRT-LLM plugin library for distributed-op converters.

    Resolution order for the shared library ``libnvinfer_plugin_tensorrt_llm.so``:

    1. The ``TRTLLM_PLUGINS_PATH`` environment variable, if set.
    2. Otherwise, if ``USE_TRTLLM_PLUGINS`` is truthy (``1``/``true``/``yes``/``on``),
       download the ``tensorrt_llm`` wheel from NVIDIA's PyPI index with ``wget``,
       extract it, and use the bundled shared library.

    Returns:
        bool: True if the plugin was successfully loaded and initialized, False otherwise.
    """
    plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
    if not plugin_lib_path:
        _LOGGER.warning(
            "Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library",
        )
        use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
            "1",
            "true",
            "yes",
            "on",
        )
        if not use_trtllm_plugin:
            # BUG FIX: message previously said "TRTLLM_PLUGIN_PATH"; the env var
            # actually consulted above is "TRTLLM_PLUGINS_PATH".
            _LOGGER.warning(
                "Neither TRTLLM_PLUGINS_PATH is set nor is it directed to download the shared library"
            )
            return False

        py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
        # BUG FIX: the original compared the Platform *class* itself to enum
        # members (`Platform == Platform.LINUX_X86_64`), which is always False,
        # so `platform` was never mapped to a wheel tag. Compare the value.
        platform = Platform.current_platform()
        if platform == Platform.LINUX_X86_64:
            platform = "linux_x86_64"
        elif platform == Platform.LINUX_AARCH64:
            platform = "linux_aarch64"

        if py_version not in ("cp310", "cp312"):
            _LOGGER.warning(
                "No available wheel for python versions other than py3.10 and py3.12"
            )
        if py_version == "cp310" and platform == "linux_aarch64":
            _LOGGER.warning("No available wheel for python3.10 with Linux aarch64")

        base_url = "https://pypi.nvidia.com/tensorrt-llm/"
        # BUG FIX: missing f-string prefix left "{py_version}"/"{platform}"
        # unsubstituted, producing an invalid download URL.
        file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
        download_url = base_url + file_name
        # Best-effort download; success is verified via os.path.exists below.
        subprocess.run(["wget", download_url], check=False)
        if not os.path.exists(file_name):
            _LOGGER.warning(f"Failed to download the wheel {file_name}")
            return False
        _LOGGER.info(f"Download of {file_name} is completed")

        import zipfile

        # Wheels are zip archives; extract to get at the bundled shared library.
        with zipfile.ZipFile(file_name, "r") as zip_ref:
            zip_ref.extractall("./tensorrt_llm")
        # BUG FIX: the original concatenated without a separator, yielding
        # "./tensorrt_llmlibnvinfer_plugin_tensorrt_llm.so".
        plugin_lib_path = os.path.join(
            "./tensorrt_llm", "libnvinfer_plugin_tensorrt_llm.so"
        )

    # BUG FIX: the CDLL load was nested inside the download branch, so an
    # env-provided TRTLLM_PLUGINS_PATH left `handle` undefined and the later
    # initTrtLlmPlugins call raised NameError. Load at function level instead.
    _LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
    try:
        # Load the shared library
        handle = ctypes.CDLL(plugin_lib_path)
        _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
    except OSError as e_os_error:
        _LOGGER.error(
            f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
            f"Ensure the path is correct and the library is compatible",
            exc_info=e_os_error,
        )
        return False

    try:
        # Configure plugin initialization arguments
        handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
        handle.initTrtLlmPlugins.restype = ctypes.c_bool
    except AttributeError as e_plugin_unavailable:
        _LOGGER.warning(
            "Unable to initialize the TensorRT-LLM plugin library",
            exc_info=e_plugin_unavailable,
        )
        return False

    try:
        # Initialize the plugin
        TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
        if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
            _LOGGER.info("TensorRT-LLM plugin successfully initialized")
            return True
        else:
            _LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
            return False
    except Exception as e_initialization_error:
        _LOGGER.warning(
            "Exception occurred during TensorRT-LLM plugin library initialization",
            exc_info=e_initialization_error,
        )
        return False
0 commit comments