
Commit 9ba407b

Add TRT-LLM installation utilities and test cases

1 parent 3e38e87 commit 9ba407b
File tree

3 files changed: +180 −31 lines

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+87 −31)
@@ -3,6 +3,7 @@
 import functools
 import logging
 import os
+import shutil
 import subprocess
 import sys
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, overload
@@ -926,19 +927,100 @@ def args_bounds_check(
     return args[i] if len(args) > i and args[i] is not None else replacement
 
 
+def install_wget(platform: str) -> None:
+    if shutil.which("wget"):
+        _LOGGER.debug("wget is already installed")
+        return
+    if platform.startswith("linux"):
+        try:
+            # Use apt-get directly when running as root; otherwise retry via sudo
+            if os.geteuid() == 0:
+                subprocess.run(["apt-get", "update"], check=True)
+                subprocess.run(["apt-get", "install", "-y", "wget"], check=True)
+            else:
+                _LOGGER.debug("Not running as root; using sudo to install wget")
+                subprocess.run(["sudo", "apt-get", "update"], check=True)
+                subprocess.run(["sudo", "apt-get", "install", "-y", "wget"], check=True)
+        except subprocess.CalledProcessError as e:
+            _LOGGER.debug(f"Error installing wget: {e}")
+
+
+def install_mpi(platform: str) -> None:
+    if platform.startswith("linux"):
+        try:
+            # Use apt-get directly when running as root; otherwise retry via sudo
+            if os.geteuid() == 0:
+                subprocess.run(["apt-get", "update"], check=True)
+                subprocess.run(["apt-get", "install", "-y", "libmpich-dev"], check=True)
+                subprocess.run(
+                    ["apt-get", "install", "-y", "libopenmpi-dev"], check=True
+                )
+            else:
+                _LOGGER.debug("Not running as root; using sudo to install MPI libraries")
+                subprocess.run(["sudo", "apt-get", "update"], check=True)
+                subprocess.run(
+                    ["sudo", "apt-get", "install", "-y", "libmpich-dev"], check=True
+                )
+                subprocess.run(
+                    ["sudo", "apt-get", "install", "-y", "libopenmpi-dev"], check=True
+                )
+        except subprocess.CalledProcessError as e:
+            _LOGGER.debug(f"Error installing MPI libraries: {e}")
+
+
+def download_plugin_lib_path(py_version: str, platform: str) -> Optional[str]:
+    plugin_lib_path = None
+    if py_version not in ("cp310", "cp312"):
+        _LOGGER.warning(
+            "No available wheel for python versions other than py3.10 and py3.12"
+        )
+    install_wget(platform)
+    base_url = "https://pypi.nvidia.com/tensorrt-llm/"
+    file_name = f"tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
+    _LOGGER.debug(f"Wheel file name: {file_name}")
+    download_url = base_url + file_name
+    cmd = ["wget", download_url]
+    try:
+        if not os.path.exists(file_name):
+            _LOGGER.info(f"Running command: {' '.join(cmd)}")
+            subprocess.run(cmd, check=True)
+            _LOGGER.info("Download of the wheel is complete")
+        if os.path.exists(file_name):
+            _LOGGER.debug(f"{file_name} is now present")
+        if os.path.exists("./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"):
+            plugin_lib_path = "./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+        else:
+            import zipfile
+
+            # The wheel is a zip archive; extracting it here creates ./tensorrt_llm/
+            with zipfile.ZipFile(file_name, "r") as zip_ref:
+                zip_ref.extractall(".")
+            plugin_lib_path = "./tensorrt_llm/libs/libnvinfer_plugin_tensorrt_llm.so"
+    except subprocess.CalledProcessError as e:
+        _LOGGER.debug(f"Error occurred while trying to download: {e}")
+    except Exception as e:
+        _LOGGER.debug(f"An unexpected error occurred: {e}")
+    return plugin_lib_path
+
+
 def load_tensorrt_llm() -> bool:
     """
     Attempts to load the TensorRT-LLM plugin and initialize it.
 
     Returns:
         bool: True if the plugin was successfully loaded and initialized, False otherwise.
     """
-
+    _LOGGER.debug("Checking whether the TensorRT-LLM plugin can be loaded")
     plugin_lib_path = os.environ.get("TRTLLM_PLUGINS_PATH")
     if not plugin_lib_path:
         _LOGGER.warning(
             "Please set the TRTLLM_PLUGINS_PATH to the directory containing libnvinfer_plugin_tensorrt_llm.so to use converters for torch.distributed ops or else set the USE_TRTLLM_PLUGINS variable to download the shared library",
         )
         use_trtllm_plugin = os.environ.get("USE_TRTLLM_PLUGINS", "0").lower() in (
             "1",
             "true",
@@ -953,38 +1035,12 @@ def load_tensorrt_llm() -> bool:
         else:
             py_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
             platform = Platform.current_platform()
-            if Platform == Platform.LINUX_X86_64:
-                platform = "linux_x86_64"
-            elif Platform == Platform.LINUX_AARCH64:
-                platform = "linux_aarch64"
-
-            if py_version not in ("cp310", "cp312"):
-                _LOGGER.warning(
-                    "No available wheel for python versions other than py3.10 and py3.12"
-                )
-            if py_version == "cp310" and platform == "linux_aarch64":
-                _LOGGER.warning("No available wheel for python3.10 with Linux aarch64")
 
-            base_url = "https://pypi.nvidia.com/tensorrt-llm/"
-            file_name = (
-                "tensorrt_llm-0.17.0.post1-{py_version}-{py_version}-{platform}.whl"
-            )
-            download_url = base_url + file_name
-            cmd = ["wget", download_url]
-            subprocess.run(cmd)
-            if os.path.exists(file_name):
-                _LOGGER.info("filename download is completed")
-                import zipfile
-
-                with zipfile.ZipFile(file_name, "r") as zip_ref:
-                    zip_ref.extractall(
-                        "./tensorrt_llm"
-                    )  # Extract to a folder named 'tensorrt_llm'
-                plugin_lib_path = (
-                    "./tensorrt_llm" + "libnvinfer_plugin_tensorrt_llm.so"
-                )
+            platform = str(platform).lower()
+            plugin_lib_path = download_plugin_lib_path(py_version, platform)
     try:
-        # Load the shared library
+        # Install MPI dependencies if needed, then load the shared library
+        install_mpi(platform)
         handle = ctypes.CDLL(plugin_lib_path)
         _LOGGER.info(f"Successfully loaded plugin library: {plugin_lib_path}")
     except OSError as e_os_error:
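For context, the loading path above reduces to two environment variables: TRTLLM_PLUGINS_PATH points at an existing libnvinfer_plugin_tensorrt_llm.so build, and USE_TRTLLM_PLUGINS opts into the wheel download added in this commit. A minimal usage sketch (the import path assumes load_tensorrt_llm stays public in converter_utils):

    import os

    # Option 1: point at an existing plugin build.
    os.environ["TRTLLM_PLUGINS_PATH"] = "/path/to/libnvinfer_plugin_tensorrt_llm.so"
    # Option 2: let this commit's download path fetch the tensorrt_llm wheel.
    os.environ["USE_TRTLLM_PLUGINS"] = "1"

    from torch_tensorrt.dynamo.conversion.converter_utils import load_tensorrt_llm

    # True once the shared library is loaded and its plugins are initialized.
    assert load_tensorrt_llm()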

tests/py/dynamo/conversion/harness.py (+13 −0)
@@ -351,6 +351,7 @@ def generate_graph(
         enable_passes: bool,
         propagate_shapes: bool = False,
         settings: CompilationSettings = CompilationSettings(),
+        fuse_distributed_ops: bool = False,
         torch_export_dynamic_shapes: Optional[Any] = None,
     ):
         mod = mod.eval()
@@ -366,6 +367,16 @@ def generate_graph(
             tuple(torch_export_inputs),
             dynamic_shapes=torch_export_dynamic_shapes,
         )
+        if fuse_distributed_ops:
+            # Alias the pass on import so it does not shadow the boolean flag
+            from torch_tensorrt.dynamo.lowering.passes.fuse_distributed_ops import (
+                fuse_distributed_ops as fuse_distributed_ops_pass,
+            )
+
+            gm = exported_program.graph_module
+            gm = fuse_distributed_ops_pass(gm, settings)
+            exported_program = exported_program.run_decompositions(
+                get_decompositions(False)
+            )
         if enable_passes:
             exported_program = pre_export_lowering(exported_program, settings)
             exported_program = exported_program.run_decompositions(
@@ -404,6 +415,7 @@ def run_test(
         propagate_shapes=False,
         int32_reqd=False,
         immutable_weights=True,
+        fuse_distributed_ops=False,
    ):
        # TODO: lan to remove this and set use_dynamo_tracer to True by default
        # once all the converter test files are moved to use_dynamo_tracer
@@ -424,6 +436,7 @@ def run_test(
             enable_passes=enable_passes,
             propagate_shapes=propagate_shapes,
             settings=compilation_settings,
+            fuse_distributed_ops=fuse_distributed_ops,
         )
 
         num_inputs = len(inputs)
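With the flag plumbed through run_test, an individual converter test opts into the distributed-ops fusion per call; a short sketch (model and inputs are placeholders, not part of this diff):

    # Inside a DispatchTestCase subclass; fuse_distributed_ops defaults to False.
    self.run_test(
        model.cuda(),
        inputs,
        use_dynamo_tracer=True,
        fuse_distributed_ops=True,  # apply the fusion pass before decompositions
    )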
New test file under tests/py/dynamo/conversion (+80 −0)

@@ -0,0 +1,80 @@
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+
+def set_environment_variables():
+    os.environ["WORLD_SIZE"] = str(1)
+    os.environ["RANK"] = str(0)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = str(29500)
+    os.environ["USE_TRTLLM_PLUGINS"] = "1"
+
+
+# Single-rank process group; the converter test below gathers within this group.
+set_environment_variables()
+dist.init_process_group(backend="nccl", init_method="env://")
+group = dist.new_group(ranks=[0])
+group_name = group.group_name
+
+from .harness import DispatchTestCase
+
+
+class TestGatherNcclOpsConverter(DispatchTestCase):
+    @parameterized.expand([(8,)])
+    def test_nccl_ops(self, linear_layer_dim):
+        class DistributedGatherModel(nn.Module):
+            def __init__(self, input_dim):
+                super().__init__()
+                self.fc = torch.nn.Linear(input_dim, input_dim)
+
+            def forward(self, x):
+                x = self.fc(x)
+                world_size = 1
+                gathered_tensor = torch.ops._c10d_functional.all_gather_into_tensor(
+                    x, world_size, group_name
+                )
+                gathered_tensor = torch.ops._c10d_functional.wait_tensor(
+                    gathered_tensor
+                )
+                return gathered_tensor
+
+        inputs = [torch.randn(1, linear_layer_dim).to("cuda")]
+        self.run_test(
+            DistributedGatherModel(linear_layer_dim).cuda(),
+            inputs,
+            use_dynamo_tracer=True,
+            fuse_distributed_ops=True,
+        )
+
+    # TODO: Look at this
+    # @parameterized.expand([(8,)])
+    # def test_nccl_ops_scatter(self, linear_layer_dim):
+    #     class DistributedReduceScatterModel(nn.Module):
+    #         def __init__(self, input_dim):
+    #             super().__init__()
+    #
+    #         def forward(self, x):
+    #             world_size = 1
+    #             scatter_reduce_tensor = (
+    #                 torch.ops._c10d_functional.reduce_scatter_tensor(
+    #                     x, "sum", world_size, group_name
+    #                 )
+    #             )
+    #             scatter_reduce_tensor = torch.ops._c10d_functional.wait_tensor(
+    #                 scatter_reduce_tensor
+    #             )
+    #             return scatter_reduce_tensor
+    #
+    #     inputs = [torch.zeros(1, linear_layer_dim).to("cuda")]
+    #     self.run_test(
+    #         DistributedReduceScatterModel(linear_layer_dim).cuda(),
+    #         inputs,
+    #         use_dynamo_tracer=True,
+    #     )
+
+
+if __name__ == "__main__":
+    run_tests()
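One property worth noting about the test above: with WORLD_SIZE fixed to 1, all_gather_into_tensor over the single-rank group is effectively an identity copy, so the converter output can be checked against the linear layer alone. A small sketch of that invariant, assuming the same process group set up at module import:

    x = torch.randn(1, 8).cuda()
    gathered = torch.ops._c10d_functional.all_gather_into_tensor(x, 1, group_name)
    gathered = torch.ops._c10d_functional.wait_tensor(gathered)
    # One rank means the gathered tensor is just the input shard.
    assert torch.equal(gathered, x)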
