Skip to content

Commit 57dbb3f

Browse files
committed
TensorRT-LLM import fix and aot_joint_export specify as explicit setting in dynamo.compile
1 parent 2368e63 commit 57dbb3f

File tree

2 files changed

+92
-46
lines changed

2 files changed

+92
-46
lines changed

py/torch_tensorrt/dynamo/_compiler.py

+9
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def cross_compile_for_windows(
9696
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
9797
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
9898
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
99+
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
99100
**kwargs: Any,
100101
) -> torch.fx.GraphModule:
101102
"""Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -169,6 +170,7 @@ def cross_compile_for_windows(
169170
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
170171
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
171172
enable_weight_streaming (bool): Enable weight streaming.
173+
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
172174
**kwargs: Any,
173175
Returns:
174176
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -326,6 +328,7 @@ def cross_compile_for_windows(
326328
"immutable_weights": immutable_weights,
327329
"enable_cross_compile_for_windows": True,
328330
"enable_weight_streaming": enable_weight_streaming,
331+
"use_aot_joint_export": use_aot_joint_export,
329332
}
330333

331334
# disable the following settings is not supported for cross compilation for windows feature
@@ -413,6 +416,7 @@ def compile(
413416
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
414417
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
415418
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
419+
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
416420
**kwargs: Any,
417421
) -> torch.fx.GraphModule:
418422
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -488,6 +492,7 @@ def compile(
488492
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
489493
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
490494
enable_weight_streaming (bool): Enable weight streaming.
495+
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
491496
**kwargs: Any,
492497
Returns:
493498
torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -662,6 +667,7 @@ def compile(
662667
"immutable_weights": immutable_weights,
663668
"enable_cross_compile_for_windows": False,
664669
"enable_weight_streaming": enable_weight_streaming,
670+
"use_aot_joint_export": use_aot_joint_export,
665671
}
666672

667673
settings = CompilationSettings(**compilation_options)
@@ -950,6 +956,7 @@ def convert_exported_program_to_serialized_trt_engine(
950956
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
951957
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
952958
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
959+
use_aot_joint_export: bool = _defaults.USE_AOT_JOINT_EXPORT,
953960
**kwargs: Any,
954961
) -> bytes:
955962
"""Convert an ExportedProgram to a serialized TensorRT engine
@@ -1013,6 +1020,7 @@ def convert_exported_program_to_serialized_trt_engine(
10131020
strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
10141021
immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
10151022
enable_weight_streaming (bool): Enable weight streaming.
1023+
use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
10161024
Returns:
10171025
bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
10181026
"""
@@ -1129,6 +1137,7 @@ def convert_exported_program_to_serialized_trt_engine(
11291137
"strip_engine_weights": strip_engine_weights,
11301138
"immutable_weights": immutable_weights,
11311139
"enable_weight_streaming": enable_weight_streaming,
1140+
"use_aot_joint_export": use_aot_joint_export,
11321141
}
11331142

11341143
settings = CompilationSettings(**compilation_options)

py/torch_tensorrt/dynamo/conversion/converter_utils.py

+83-46
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import functools
44
import logging
55
import os
6+
import subprocess
7+
import sys
68
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
79

810
import numpy as np
@@ -12,6 +14,7 @@
1214
from torch.fx.node import Argument, Target
1315
from torch.fx.passes.shape_prop import TensorMetadata
1416
from torch_tensorrt import _enums
17+
from torch_tensorrt._enums import Platform
1518
from torch_tensorrt.dynamo._settings import CompilationSettings
1619
from torch_tensorrt.dynamo._SourceIR import SourceIR
1720
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
@@ -930,57 +933,91 @@ def load_tensorrt_llm() -> bool:
930933
Returns:
931934
bool: True if the plugin was successfully loaded and initialized, False otherwise.
932935
"""
933-
try:
934-
import tensorrt_llm as trt_llm # noqa: F401
935936

936-
_LOGGER.info("TensorRT-LLM successfully imported")
937-
return True
938-
except (ImportError, AssertionError) as e_import_error:
939-
# Check for environment variable for the plugin library path
940-
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
941-
if not plugin_lib_path:
937+
plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
938+
if not plugin_lib_path:
939+
_LOGGER.warning(
940+
"Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library",
941+
)
942+
use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
943+
"1",
944+
"true",
945+
"yes",
946+
"on",
947+
)
948+
if not use_trtllm_plugin:
942949
_LOGGER.warning(
943-
"TensorRT-LLM is not installed. Please install TensorRT-LLM or set TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops",
950+
"Neither TRTLLM_PLUGINS_PATH is set nor is USE_TRTLLM_PLUGINS enabled to download the shared library"
944951
)
945952
return False
946-
947-
_LOGGER.info(f"TensorRT-LLM Plugin lib path found: {plugin_lib_path}")
948-
try:
949-
# Load the shared library
950-
handle = ctypes.CDLL(plugin_lib_path)
951-
_LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
952-
except OSError as e_os_error:
953-
_LOGGER.error(
954-
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
955-
f"Ensure the path is correct and the library is compatible",
956-
exc_info=e_os_error,
953+
else:
954+
py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
955+
platform = Platform.current_platform()
956+
if platform == Platform.LINUX_X86_64:
957+
platform = "linux_x86_64"
958+
elif platform == Platform.LINUX_AARCH64:
959+
platform = "linux_aarch64"
960+
961+
if py_version not in ("cp310", "cp312"):
962+
_LOGGER.warning(
963+
"No available wheel for python versions other than py3.10 and py3.12"
964+
)
965+
if py_version == "cp310" and platform == "linux_aarch64":
966+
_LOGGER.warning("No available wheel for python3.10 with Linux aarch64")
967+
968+
base_url = "https://pypi.nvidia.com/tensorrt-llm/"
969+
file_name = (
970+
f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
957971
)
)
958-
return False
972+
download_url = base_url + file_name
973+
cmd = ["wget", download_url]
974+
subprocess.run(cmd)
975+
if os.path.exists(file_name):
976+
_LOGGER.info("file download is completed")
977+
import zipfile
978+
979+
with zipfile.ZipFile(file_name, "r") as zip_ref:
980+
zip_ref.extractall(
981+
"./tensorrt_llm"
982+
) # Extract to a folder named 'tensorrt_llm'
983+
plugin_lib_path = (
984+
"./tensorrt_llm" + "/libnvinfer_plugin_tensorrt_llm.so"
985+
)
)
986+
try:
987+
# Load the shared library
988+
handle = ctypes.CDLL(plugin_lib_path)
989+
_LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
990+
except OSError as e_os_error:
991+
_LOGGER.error(
992+
f"Failed to load libnvinfer_plugin_tensorrt_llm.so from {plugin_lib_path}"
993+
f"Ensure the path is correct and the library is compatible",
994+
exc_info=e_os_error,
995+
)
996+
return False
959997

960-
try:
961-
# Configure plugin initialization arguments
962-
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
963-
handle.initTrtLlmPlugins.restype = ctypes.c_bool
964-
except AttributeError as e_plugin_unavailable:
965-
_LOGGER.warning(
966-
"Unable to initialize the TensorRT-LLM plugin library",
967-
exc_info=e_plugin_unavailable,
968-
)
969-
return False
998+
try:
999+
# Configure plugin initialization arguments
1000+
handle.initTrtLlmPlugins.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
1001+
handle.initTrtLlmPlugins.restype = ctypes.c_bool
1002+
except AttributeError as e_plugin_unavailable:
1003+
_LOGGER.warning(
1004+
"Unable to initialize the TensorRT-LLM plugin library",
1005+
exc_info=e_plugin_unavailable,
1006+
)
1007+
return False
9701008

971-
try:
972-
# Initialize the plugin
973-
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
974-
if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
975-
_LOGGER.info("TensorRT-LLM plugin successfully initialized")
976-
return True
977-
else:
978-
_LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
979-
return False
980-
except Exception as e_initialization_error:
981-
_LOGGER.warning(
982-
"Exception occurred during TensorRT-LLM plugin library initialization",
983-
exc_info=e_initialization_error,
984-
)
1009+
try:
1010+
# Initialize the plugin
1011+
TRT_LLM_PLUGIN_NAMESPACE = "tensorrt_llm"
1012+
if handle.initTrtLlmPlugins(None, TRT_LLM_PLUGIN_NAMESPACE.encode("utf-8")):
1013+
_LOGGER.info("TensorRT-LLM plugin successfully initialized")
1014+
return True
1015+
else:
1016+
_LOGGER.warning("TensorRT-LLM plugin library failed in initialization")
9851017
return False
986-
return False
1018+
except Exception as e_initialization_error:
1019+
_LOGGER.warning(
1020+
"Exception occurred during TensorRT-LLM plugin library initialization",
1021+
exc_info=e_initialization_error,
1022+
)
1023+
return False

0 commit comments

Comments
 (0)