Skip to content

Commit aed7cb1

Browse files
committed
test: support automatic plugin feature with different dimensions and add flashinfer.rmsnorm support test case
1 parent c8155f5 commit aed7cb1

File tree

6 files changed

+67
-14
lines changed

6 files changed

+67
-14
lines changed

Diff for: .github/workflows/build-test-linux.yml

+1
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ jobs:
143143
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml -n 4 conversion/
144144
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin.py
145145
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
146+
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/flashinfer_plugin.py
146147
popd
147148
148149
tests-py-dynamo-fe:

Diff for: py/requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ pybind11==2.6.2
55
torch>=2.7.0.dev,<2.8.0
66
torchvision>=0.22.0.dev,<0.23.0
77
--extra-index-url https://pypi.ngc.nvidia.com
8-
pyyaml
8+
pyyaml

Diff for: py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import itertools
12
import logging
23
from types import FunctionType
34
from typing import Any, Callable, Tuple
@@ -130,16 +131,25 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
130131
output = torch_op(*fake_args, **kwargs)
131132

132133
# We assume that number of dimensions are the same in torch op
133-
shape_calc_fns = [None] * args[0].ndim
134-
for i in range(args[0].ndim):
135-
input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
134+
shape_calc_fns = [None] * output.ndim
135+
136+
for i in range(output.ndim):
137+
input_node_expr = list(
138+
itertools.chain.from_iterable(
139+
[sym.node.expr for sym in syms_arg] for syms_arg in syms_args
140+
)
141+
)
142+
136143
shape_calc_fns[i] = lambdify(
137144
tuple(input_node_expr), output.shape[i].node.expr, "math"
138145
)
139146

140147
out_desc = tensor_args[0].like()
141148
for i in range(out_desc.ndim):
142-
input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
149+
input_shape_expr = list(
150+
itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args)
151+
)
152+
143153
if output.shape[i].node.expr is None:
144154
raise ValueError(f"output.shape[{i}].node.expr cannot be None")
145155
out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc]

Diff for: tests/py/dynamo/automatic_plugin/test_automatic_plugin.py

-9
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,3 @@ def forward(self, lhs, rhs):
8181

8282
if __name__ == "__main__":
8383
run_tests()
84-
85-
# Example Usage
86-
# A = torch.full((64, 64), 2, device="cuda", dtype=torch.float)
87-
# B = torch.full((64, 64), 3, device="cuda", dtype=torch.float)
88-
89-
# C, D = torch.ops.torchtrt_ex.elementwise_add_mul.default(A, B)
90-
91-
# print("C (Addition):", C)
92-
# print("D (Multiplication):", D)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import flashinfer
2+
import torch
3+
import torch.nn as nn
4+
import torch_tensorrt
5+
from parameterized import parameterized
6+
from torch.testing._internal.common_utils import run_tests
7+
from torch_tensorrt._enums import dtype
8+
9+
from ..conversion.harness import DispatchTestCase
10+
11+
12+
@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=()) # type: ignore[misc]
13+
def flashinfer_rmsnorm(
14+
input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
15+
) -> torch.Tensor:
16+
return flashinfer.norm.rmsnorm(input, weight)
17+
18+
19+
@torch.library.register_fake("flashinfer::rmsnorm")
20+
def _(input: torch.Tensor, weight: torch.Tensor, b: float = 1e-6) -> torch.Tensor:
21+
return input
22+
23+
24+
torch_tensorrt.dynamo.conversion.plugins.custom_op(
25+
"flashinfer::rmsnorm", supports_dynamic_shapes=True
26+
)
27+
28+
29+
class TestAutomaticPlugin(DispatchTestCase):
30+
@parameterized.expand(
31+
[
32+
((64, 64), (64,), torch.float16),
33+
((256, 256), (256,), torch.float16),
34+
]
35+
)
36+
def test_rmsnorm_float(self, input_shape, weight_shape, data_type):
37+
class rmsnorm(nn.Module):
38+
def forward(self, input, weight):
39+
return torch.ops.flashinfer.rmsnorm.default(input, weight)
40+
41+
inputs = [
42+
torch.randn(input_shape, device="cuda", dtype=data_type),
43+
torch.randn(weight_shape, device="cuda", dtype=data_type),
44+
]
45+
46+
self.run_test(rmsnorm(), inputs, precision=dtype.f16)
47+
48+
49+
if __name__ == "__main__":
50+
run_tests()

Diff for: tests/py/requirements.txt

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ pytest>=8.2.1
88
pytest-xdist>=3.6.1
99
pyyaml
1010
timm>=1.0.3
11+
flashiner-python
1112
transformers==4.40.2
1213
nvidia-modelopt[deploy,hf,torch]~=0.17.0
1314
--extra-index-url https://pypi.nvidia.com

0 commit comments

Comments
 (0)