3
3
import functools
4
4
import logging
5
5
import os
6
+ import shutil
6
7
import subprocess
7
8
import sys
8
9
from typing import Any , Callable , Dict , List , Optional , Sequence , Tuple , Union , overload
@@ -926,19 +927,100 @@ def args_bounds_check(
926
927
return args [i ] if len (args ) > i and args [i ] is not None else replacement
927
928
928
929
930
+ def install_wget (platform : str ) -> None :
931
+ if shutil .which ("wget" ):
932
+ _LOGGER .debug ("wget is already installed" )
933
+ return
934
+ if platform .startswith ("linux" ):
935
+ try :
936
+ # if its root
937
+ if os .geteuid () == 0 :
938
+ subprocess .run (["apt-get" , "update" ], check = True )
939
+ subprocess .run (["apt-get" , "install" , "-y" , "wget" ], check = True )
940
+ else :
941
+ _LOGGER .debug ("Please run with sudo permissions" )
942
+ subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
943
+ subprocess .run (["sudo" , "apt-get" , "install" , "-y" , "wget" ], check = True )
944
+ except subprocess .CalledProcessError as e :
945
+ _LOGGER .debug ("Error installing wget:" , e )
946
+
947
+
948
+ def install_mpi (platform : str ) -> None :
949
+ if platform .startswith ("linux" ):
950
+ try :
951
+ # if its root
952
+ if os .geteuid () == 0 :
953
+ subprocess .run (["apt-get" , "update" ], check = True )
954
+ subprocess .run (["apt-get" , "install" , "-y" , "libmpich-dev" ], check = True )
955
+ subprocess .run (
956
+ ["apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
957
+ )
958
+ else :
959
+ _LOGGER .debug ("Please run with sudo permissions" )
960
+ subprocess .run (["sudo" , "apt-get" , "update" ], check = True )
961
+ subprocess .run (
962
+ ["sudo" , "apt-get" , "install" , "-y" , "libmpich-dev" ], check = True
963
+ )
964
+ subprocess .run (
965
+ ["sudo" , "apt-get" , "install" , "-y" , "libopenmpi-dev" ], check = True
966
+ )
967
+ except subprocess .CalledProcessError as e :
968
+ _LOGGER .debug ("Error installing mpi libs:" , e )
969
+
970
+
971
+ def download_plugin_lib_path (py_version : str , platform : str ) -> str :
972
+ plugin_lib_path = None
973
+ if py_version not in ("cp310" , "cp312" ):
974
+ _LOGGER .warning (
975
+ "No available wheel for python versions other than py3.10 and py3.12"
976
+ )
977
+ install_wget (platform )
978
+ base_url = "https://pypi.nvidia.com/tensorrt-llm/"
979
+ file_name = f"tensorrt_llm-0.17.0.post1-{ py_version } -{ py_version } -{ platform } .whl"
980
+ print ("file_name is===" , file_name )
981
+ download_url = base_url + file_name
982
+ cmd = ["wget" , download_url ]
983
+ try :
984
+ if not (os .path .exists (file_name )):
985
+ _LOGGER .info (f"Running command: { ' ' .join (cmd )} " )
986
+ subprocess .run (cmd )
987
+ _LOGGER .info ("Download complete of wheel" )
988
+ if os .path .exists (file_name ):
989
+ _LOGGER .info ("filename now present" )
990
+ if os .path .exists ("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so" ):
991
+ plugin_lib_path = (
992
+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
993
+ )
994
+ else :
995
+ import zipfile
996
+
997
+ with zipfile .ZipFile (file_name , "r" ) as zip_ref :
998
+ zip_ref .extractall ("." ) # Extract to a folder named 'tensorrt_llm'
999
+ plugin_lib_path = (
1000
+ "./tensorrt_llm/libs/" + "libnvinfer_plugin_tensorrt_llm.so"
1001
+ )
1002
+ except subprocess .CalledProcessError as e :
1003
+ _LOGGER .debug (f"Error occurred while trying to download: { e } " )
1004
+ except Exception as e :
1005
+ _LOGGER .debug (f"An unexpected error occurred: { e } " )
1006
+ return plugin_lib_path
1007
+
1008
+
929
1009
def load_tensorrt_llm () -> bool :
930
1010
"""
931
1011
Attempts to load the TensorRT-LLM plugin and initialize it.
932
1012
933
1013
Returns:
934
1014
bool: True if the plugin was successfully loaded and initialized, False otherwise.
935
1015
"""
936
-
1016
+ print ( "coming to check load_tensorrt_llm!!!!" )
937
1017
plugin_lib_path = os .environ .get ("TRTLLM_PLUGINS_PATH" )
938
1018
if not plugin_lib_path :
939
1019
_LOGGER .warning (
940
1020
"Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library" ,
941
1021
)
1022
+ for key , value in os .environ .items ():
1023
+ print (f"{ key } : { value } " )
942
1024
use_trtllm_plugin = os .environ .get ("USE_TRTLLM_PLUGINS" , "0" ).lower () in (
943
1025
"1" ,
944
1026
"true" ,
@@ -953,38 +1035,12 @@ def load_tensorrt_llm() -> bool:
953
1035
else :
954
1036
py_version = f"cp{ sys .version_info .major } { sys .version_info .minor } "
955
1037
platform = Platform .current_platform ()
956
- if Platform == Platform .LINUX_X86_64 :
957
- platform = "linux_x86_64"
958
- elif Platform == Platform .LINUX_AARCH64 :
959
- platform = "linux_aarch64"
960
-
961
- if py_version not in ("cp310" , "cp312" ):
962
- _LOGGER .warning (
963
- "No available wheel for python versions other than py3.10 and py3.12"
964
- )
965
- if py_version == "cp310" and platform == "linux_aarch64" :
966
- _LOGGER .warning ("No available wheel for python3.10 with Linux aarch64" )
967
1038
968
- base_url = "https://pypi.nvidia.com/tensorrt-llm/"
969
- file_name = (
970
- "tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
971
- )
972
- download_url = base_url + file_name
973
- cmd = ["wget" , download_url ]
974
- subprocess .run (cmd )
975
- if os .path .exists (file_name ):
976
- _LOGGER .info ("filename download is completed" )
977
- import zipfile
978
-
979
- with zipfile .ZipFile (file_name , "r" ) as zip_ref :
980
- zip_ref .extractall (
981
- "./tensorrt_llm"
982
- ) # Extract to a folder named 'tensorrt_llm'
983
- plugin_lib_path = (
984
- "./tensorrt_llm" + "libnvinfer_plugin_tensorrt_llm.so"
985
- )
1039
+ platform = str (platform ).lower ()
1040
+ plugin_lib_path = download_plugin_lib_path (py_version , platform )
986
1041
try :
987
- # Load the shared library
1042
+ # Load the shared
1043
+ install_mpi (platform )
988
1044
handle = ctypes .CDLL (plugin_lib_path )
989
1045
_LOGGER .info (f"Successfully loaded plugin library: { plugin_lib_path } " )
990
1046
except OSError as e_os_error :
0 commit comments