use IOutputAllocator
zewenli98 committed Jan 24, 2025
1 parent 43831dc commit 3b60296
Showing 4 changed files with 142 additions and 50 deletions.
18 changes: 17 additions & 1 deletion py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 from torch.fx.node import Argument, Node, Target
-
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._SourceIR import SourceIR
 from torch_tensorrt.dynamo.conversion import impl
@@ -3580,3 +3579,20 @@ def aten_ops_full(
         fill_value=args[1],
         dtype=kwargs.get("dtype", None),
     )
+
+
+@dynamo_tensorrt_converter(torch.ops.aten.nonzero.default)
+def aten_ops_nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    return impl.unary.nonzero(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        args[0],
+    )
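
For context: the converter above maps torch.ops.aten.nonzero.default onto impl.unary.nonzero (shown in the next file). A rough end-to-end sketch of how it might be exercised — the module, input, and compile options below are illustrative and not part of this commit, and data-dependent output shapes additionally rely on the runtime changes further down:

import torch
import torch_tensorrt

class NonZeroModel(torch.nn.Module):  # hypothetical example module, not from the commit
    def forward(self, x):
        return torch.ops.aten.nonzero.default(x)

model = NonZeroModel().eval().cuda()
example = [torch.randint(0, 2, (8, 8), dtype=torch.int32, device="cuda")]
# Compile through the dynamo path so the new converter is used.
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=example, min_block_size=1)
out = trt_model(*example)
print(out.shape)  # (num_nonzero, 2), matching torch.nonzero
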
15 changes: 15 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -625,3 +625,18 @@ def native_dropout(
     mask = np.ones(input_val.shape, dtype=bool)
     mask = get_trt_tensor(ctx, mask, f"{name}_mask")
     return identity_layer.get_output(0), mask
+
+
+def nonzero(
+    ctx: ConversionContext,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input_val: TRTTensor,
+) -> TRTTensor:
+    non_zero_layer = ctx.net.add_non_zero(input_val)
+    set_layer_name(non_zero_layer, target, f"{name}_non_zero", source_ir)
+    shuffle_layer = ctx.net.add_shuffle(non_zero_layer.get_output(0))
+    shuffle_layer.first_transpose = trt.Permutation([1, 0])
+    set_layer_name(shuffle_layer, target, f"{name}_transpose", source_ir)
+    return shuffle_layer.get_output(0)
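
Why the transpose in nonzero above: TensorRT's INonZeroLayer produces indices of shape (input rank, number of non-zero elements), whereas aten.nonzero returns the transposed layout (number of non-zero elements, input rank), so the shuffle with Permutation([1, 0]) restores PyTorch's convention. A small reference example of the expected PyTorch layout (illustrative, not from the commit):

import torch

x = torch.tensor([[0, 5], [7, 0]])
print(torch.ops.aten.nonzero.default(x))
# tensor([[0, 1],
#         [1, 0]])   # shape (num_nonzero, rank) == (2, 2)
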
127 changes: 78 additions & 49 deletions py/torch_tensorrt/dynamo/runtime/_PythonTorchTensorRTModule.py
@@ -12,7 +12,6 @@
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import Platform, dtype
 from torch_tensorrt.dynamo._settings import CompilationSettings
-from torch_tensorrt.dynamo.utils import DYNAMIC_DIM
 from torch_tensorrt.logging import TRT_LOGGER
 from torch_tensorrt.runtime._utils import (
     _is_switch_required,
@@ -23,6 +22,42 @@
 logger = logging.getLogger(__name__)
 
 
+class OutputAllocator(trt.IOutputAllocator): # type: ignore[misc]
+    def __init__(self) -> None:
+        trt.IOutputAllocator.__init__(self)
+        self.buffers: Dict[str, torch.Tensor] = {}
+        self.shapes: Dict[str, Tuple[int, ...]] = {}
+
+    def reallocate_output(
+        self, tensor_name: str, memory: int, size: int, alignment: int
+    ) -> Any:
+        shape = (size,)
+        if tensor_name not in self.buffers:
+            self.buffers[tensor_name] = torch.empty(
+                shape, dtype=torch.float, device=torch.cuda.current_device()
+            )
+        else:
+            self.buffers[tensor_name] = self.resize_or_reallocate(
+                self.buffers[tensor_name], shape
+            )
+        return self.data_ptr(self.buffers[tensor_name])
+
+    def notify_shape(self, tensor_name: str, shape: Tuple[int, ...]) -> None:
+        self.shapes[tensor_name] = tuple(shape)
+
+    def resize_or_reallocate(
+        self, buffer: torch.Tensor, shape: Tuple[int, ...]
+    ) -> torch.Tensor:
+        if buffer.shape != shape:
+            buffer = torch.empty(
+                shape, dtype=torch.float, device=torch.cuda.current_device()
+            )
+        return buffer
+
+    def data_ptr(self, buffer: torch.Tensor) -> Any:
+        return buffer.data_ptr()
+
+
 class TorchTRTRuntimeStates:
     def __init__(self, new_cudagraphs: bool):
         # Indicates whether CUDAGraphs were enabled in the previous execute_engine
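
For orientation, OutputAllocator above implements TensorRT's IOutputAllocator callback interface: reallocate_output is invoked when the engine needs (possibly resized) device memory for an output, and notify_shape reports the final, data-dependent shape after execution. A minimal sketch of how such an allocator is typically wired into an execution context — the context, output_names, and stream below are assumed to already exist, and this is not code from the commit:

allocator = OutputAllocator()
for output_name in output_names:
    # Route this output through the allocator instead of a pre-bound address.
    context.set_output_allocator(output_name, allocator)

context.execute_async_v3(stream.cuda_stream)  # engine calls reallocate_output/notify_shape
stream.synchronize()

for output_name in output_names:
    shape = allocator.shapes[output_name]    # reported via notify_shape
    buffer = allocator.buffers[output_name]  # flat float32 scratch buffer
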
@@ -147,6 +182,8 @@ def __init__(
         self.output_names = (
             output_binding_names if output_binding_names is not None else []
         )
+        self.output_allocator = OutputAllocator()
+
         self.initialized = False
         self.target_device_id = (
             settings.device.gpu_id
@@ -342,19 +379,6 @@ def setup_input_tensors(
                     input_name, contiguous_inputs[i].data_ptr()
                 )
 
-    def create_output_tensors(self) -> List[torch.Tensor]:
-        # create output tensors
-        outputs: List[torch.Tensor] = []
-
-        for o, _ in enumerate(self.output_names):
-            output = torch.empty(
-                size=self.output_shapes[o],
-                dtype=self.output_dtypes[o],
-                device=torch.cuda.current_device(),
-            )
-            outputs.append(output)
-        return outputs
-
     def set_pre_allocated_outputs(self, enable: bool) -> None:
         self.use_pre_allocated_outputs = enable
 
@@ -445,47 +469,18 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
                                 This could happen if the input tensor addresses/shapes haven't been configured correctly"
                         )
 
-            with (
-                torch.autograd.profiler.record_function(
-                    "PythonTorchTensorRTModule:ProcessOutputs"
-                )
-                if self.profiling_enabled
-                else nullcontext()
-            ):
-                if can_use_pre_allocated_outputs:
-                    outputs = self.pre_allocated_outputs
-                else:
-                    self.output_shapes = [
-                        tuple(self.context.get_tensor_shape(output_name))
-                        for output_name in self.output_names
-                    ]
-                    if DYNAMIC_DIM in self.output_shapes:
-                        raise ValueError(
-                            "Encountered dynamic output shapes during runtime. This could mean the network has data-dependent output shapes which is not currently supported."
-                        )
-                    outputs = self.create_output_tensors()
-
-                for o, output_name in enumerate(self.output_names):
-
-                    if need_cudagraphs_record:
-                        self._output_buffers[o] = outputs[o].clone()
-
-                    if cudagraphs_enabled:
-                        self.context.set_tensor_address(
-                            output_name, self._output_buffers[o].data_ptr()
-                        )
-                    else:
-                        self.context.set_tensor_address(
-                            output_name, outputs[o].data_ptr()
-                        )
-
             with (
                 torch.autograd.profiler.record_function(
                     "PythonTorchTensorRTModule:TensorRTRuntime"
                 )
                 if self.profiling_enabled
                 else nullcontext()
             ):
+                for output_name in self.output_names:
+                    self.context.set_output_allocator(
+                        output_name, self.output_allocator
+                    )
+
                 self._caller_stream = torch.cuda.current_stream()
                 if (
                     self._engine_stream == torch.cuda.default_stream()
@@ -526,8 +521,42 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
 
                 self._caller_stream.wait_stream(self._engine_stream)
 
+            with (
+                torch.autograd.profiler.record_function(
+                    "PythonTorchTensorRTModule:ProcessOutputs"
+                )
+                if self.profiling_enabled
+                else nullcontext()
+            ):
+                if can_use_pre_allocated_outputs:
+                    outputs = self.pre_allocated_outputs
+                else:
+                    outputs = []
+                    for o, output_name in enumerate(self.output_names):
+                        shape = self.output_allocator.shapes.get(output_name, None)
+                        self.output_shapes[o] = shape
+                        dtype = self.output_dtypes[o]
+                        output = self.output_allocator.buffers.get(output_name, None)
+                        prod = int(torch.prod(torch.tensor(shape)))
+                        output = output.reshape(-1).view(dtype)[:prod].reshape(shape)
+                        outputs.append(output)
+
+                for o, output_name in enumerate(self.output_names):
+
+                    if need_cudagraphs_record:
+                        self._output_buffers[o] = outputs[o].clone()
+
+                    if cudagraphs_enabled:
+                        self.context.set_tensor_address(
+                            output_name, self._output_buffers[o].data_ptr()
+                        )
+                    else:
+                        self.context.set_tensor_address(
+                            output_name, outputs[o].data_ptr()
+                        )
+
             if self.use_pre_allocated_outputs:
-                self.pre_allocated_outputs = self.create_output_tensors()
+                self.pre_allocated_outputs = outputs
 
             if cudagraphs_enabled:
                 for idx, o in enumerate(outputs):
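
The reshape(-1).view(dtype)[:prod].reshape(shape) chain in the new ProcessOutputs block reinterprets the allocator's flat float32 scratch buffer as the true output dtype and trims it to the shape reported through notify_shape. A standalone illustration with made-up sizes (not from the commit):

import torch

raw = torch.empty(64, dtype=torch.float, device="cuda")  # allocator-owned scratch buffer
shape, dtype = (5, 3), torch.int32                        # as reported at runtime
prod = int(torch.prod(torch.tensor(shape)))               # number of valid elements
# Reinterpret the bytes as the target dtype, keep only the valid prefix, restore the shape.
out = raw.reshape(-1).view(dtype)[:prod].reshape(shape)
assert out.shape == torch.Size(shape)
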
32 changes: 32 additions & 0 deletions tests/py/dynamo/conversion/test_nonzero_aten.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+
+from .harness import DispatchTestCase
+
+
+class TestNonZeroConverter(DispatchTestCase):
+    @parameterized.expand(
+        [
+            ((10,), torch.int),
+            ((1, 20), torch.int32),
+            ((2, 3), torch.int64),
+            ((2, 3, 4), torch.float),
+            ((2, 3, 4, 5), torch.float),
+        ]
+    )
+    def test_non_zero_float(self, input_shape, dtype):
+        class NonZero(nn.Module):
+            def forward(self, input):
+                return torch.ops.aten.nonzero.default(input)
+
+        inputs = [torch.randint(low=0, high=3, size=input_shape, dtype=dtype)]
+        self.run_test(
+            NonZero(),
+            inputs,
+        )
+
+
+if __name__ == "__main__":
+    run_tests()
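
Assuming the repository's usual pytest-based workflow, this conversion test can presumably be run on its own with:

python -m pytest tests/py/dynamo/conversion/test_nonzero_aten.py
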
