Skip to content

Commit

Permalink
support dds and nonzero op in _PythonTorchTensorRTModule
Browse files Browse the repository at this point in the history
  • Loading branch information
zewenli98 committed Jan 23, 2025
1 parent 43831dc commit 2ee9299
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 9 deletions.
8 changes: 8 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class TRTInterpreterResult(NamedTuple):
input_names: Sequence[str]
output_names: Sequence[str]
weight_name_map: Optional[dict[Any, Any]]
output_shapes: Optional[Sequence[Tuple[int]]]


class TRTInterpreter(torch.fx.Interpreter): # type: ignore[misc]
Expand Down Expand Up @@ -132,6 +133,7 @@ def __init__(
# Mapping of constants to shapes and dtypes
self.const_mapping: Dict[str, Tuple[Sequence[int], str]] = {}
self.weight_name_map: Optional[Dict[str, Any]] = None
self.output_shapes: Sequence[Tuple[int]] = []

# Engine cache for storing and reusing TRT engines
self.engine_cache = engine_cache
Expand Down Expand Up @@ -651,6 +653,7 @@ def _pull_cached_engine(self, hash_val: str) -> Optional[TRTInterpreterResult]:
self._input_names,
self._output_names,
self.weight_name_map,
self.output_shapes if self.output_shapes else None,
)
return None

Expand Down Expand Up @@ -731,11 +734,16 @@ def run(
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()

for node in self.module.graph.nodes:
if node.op == "output":
self.output_shapes.append(tuple(node.meta["tensor_meta"].shape))

return TRTInterpreterResult(
engine_str,
self._input_names,
self._output_names,
self.weight_name_map,
self.output_shapes if self.output_shapes else None,
)

def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
Expand Down
17 changes: 17 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3580,3 +3580,20 @@ def aten_ops_full(
fill_value=args[1],
dtype=kwargs.get("dtype", None),
)


@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
def aten_ops_nonzero(
ctx: ConversionContext,
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
return impl.unary.nonzero(
ctx,
target,
SourceIR.ATEN,
name,
args[0],
)
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,3 +625,18 @@ def native_dropout(
mask = np.ones(input_val.shape, dtype=bool)
mask = get_trt_tensor(ctx, mask, f"{name}_mask")
return identity_layer.get_output(0), mask


def nonzero(
ctx: ConversionContext,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input_val: TRTTensor,
) -> TRTTensor:
non_zero_layer = ctx.net.add_non_zero(input_val)
set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
shuffle_layer.first_transpose = trt.Permutation([1, 0])
set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
return shuffle_layer.get_output(0)
24 changes: 15 additions & 9 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
serialized_engine: Optional[bytes] = None,
input_binding_names: Optional[List[str]] = None,
output_binding_names: Optional[List[str]] = None,
output_shapes: Optional[Sequence[Tuple[int]]] = None,
*,
name: str = "",
settings: CompilationSettings = CompilationSettings(),
Expand All @@ -100,6 +101,7 @@ def __init__(
serialized_engine (bytes): Serialized TensorRT engine in the form of a bytearray
input_binding_names (List[str]): List of input TensorRT engine binding names in the order they would be passed to the TRT modules
output_binding_names (List[str]): List of output TensorRT engine binding names in the order they should be returned
output_shapes (Sequence[Tuple]): List of output shapes for the engine. For some cases, output shapes are dynamic and depends on input data, like NonZero op, so we need to explicitly provide output shapes
Keyword Arguments:
name (str): Name for module
Expand Down Expand Up @@ -147,6 +149,7 @@ def __init__(
self.output_names = (
output_binding_names if output_binding_names is not None else []
)
self.output_shapes = output_shapes
self.initialized = False
self.target_device_id = (
settings.device.gpu_id
Expand Down Expand Up @@ -233,10 +236,12 @@ def setup_engine(self) -> None:
dtype._from(self.engine.get_tensor_dtype(output_name)).to(torch.dtype)
for output_name in self.output_names
]
self.output_shapes = [
self.engine.get_tensor_shape(output_name)
for output_name in self.output_names
]

if self.output_shapes is None:
self.output_shapes = [
self.engine.get_tensor_shape(output_name)
for output_name in self.output_names
]

if torch_tensorrt.runtime.get_cudagraphs_mode():
self.cudagraph = torch.cuda.CUDAGraph()
Expand Down Expand Up @@ -345,7 +350,7 @@ def setup_input_tensors(
def create_output_tensors(self) -> List[torch.Tensor]:
# create output tensors
outputs: List[torch.Tensor] = []

assert self.output_shapes is not None
for o, _ in enumerate(self.output_names):
output = torch.empty(
size=self.output_shapes[o],
Expand Down Expand Up @@ -455,10 +460,11 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
if can_use_pre_allocated_outputs:
outputs = self.pre_allocated_outputs
else:
self.output_shapes = [
tuple(self.context.get_tensor_shape(output_name))
for output_name in self.output_names
]
if self.output_shapes is None:
self.output_shapes = [
tuple(self.context.get_tensor_shape(output_name))
for output_name in self.output_names
]
if DYNAMIC_DIM in self.output_shapes:
raise ValueError(
"Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
Expand Down
2 changes: 2 additions & 0 deletions tests/py/dynamo/conversion/harness.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def run_test(
serialized_engine=interpreter_result.serialized_engine,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
output_shapes=list(interpreter_result.output_shapes),
name="test_engine",
)
mod = mod.cuda()
Expand Down Expand Up @@ -288,6 +289,7 @@ def run_test_custom_compare_results(
serialized_engine=interpreter_result.serialized_engine,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
output_shapes=list(interpreter_result.output_shapes),
name="test_engine",
)
res_trt = trt_mod(*cuda_inputs).cpu()
Expand Down
33 changes: 33 additions & 0 deletions tests/py/dynamo/conversion/test_nonzero_aten.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
import torch.nn as nn
from parameterized import parameterized
from torch.testing._internal.common_utils import run_tests

from .harness import DispatchTestCase


class TestAtanConverter(DispatchTestCase):
@parameterized.expand(
[
((10,), torch.int),
((1, 20), torch.int32),
((5, 3), torch.int64),
((2, 3, 4), torch.float),
((2, 3, 4, 5), torch.float),
]
)
def test_atan_float(self, input_shape, dtype):
class atan(nn.Module):
def forward(self, input):
return torch.ops.aten.nonzero.default(input)

inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
self.run_test(
atan(),
inputs,
propagate_shapes=True, # it requires propagate_shapes=True to get output shapes
)


if __name__ == "__main__":
run_tests()

0 comments on commit 2ee9299

Please sign in to comment.