diff --git a/torch2trt/converters/__init__.py b/torch2trt/converters/__init__.py
index ffce493f..88a80059 100644
--- a/torch2trt/converters/__init__.py
+++ b/torch2trt/converters/__init__.py
@@ -32,6 +32,7 @@
 from .div import *
 from .einsum import *
 from .expand import *
+from .expand_as import *
 from .example_plugin import *
 from .floordiv import *
 from .gelu import *
diff --git a/torch2trt/converters/expand_as.py b/torch2trt/converters/expand_as.py
new file mode 100644
index 00000000..fea5bd06
--- /dev/null
+++ b/torch2trt/converters/expand_as.py
@@ -0,0 +1,52 @@
+from torch2trt.torch2trt import *
+from torch2trt.module_test import add_module_test
+
+
+@tensorrt_converter('torch.Tensor.expand_as')
+def convert_expand_as(ctx):
+    input = ctx.method_args[0]
+    output = ctx.method_return
+
+    inshape = tuple(input.shape)[1:]  # exclude the implicit batch dimension
+    shape = tuple(output.shape)[1:]
+    ndim = len(shape)
+    start = tuple([0] * ndim)
+    stride = tuple([int(i == o) for i, o in zip(inshape, shape)])  # stride 1 where dims match; stride 0 re-reads the single input element, i.e. broadcasts
+
+    layer = ctx.network.add_slice(input._trt, start, shape, stride)  # a stride-0 slice repeats input values, implementing expand_as
+
+    output._trt = layer.get_output(0)
+
+
+class ExpandAsModule(torch.nn.Module):
+    def __init__(self, other: torch.Tensor):
+        super(ExpandAsModule, self).__init__()
+        self.other = other
+
+    def forward(self, x: torch.Tensor):
+        return x.expand_as(self.other)
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1,)])
+def test_tensor_expand_as_scalar():
+    return ExpandAsModule(torch.randn(3))
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 1, 3, 3)])
+def test_tensor_expand_as_singledim():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))
+
+
+@add_module_test(torch.float32, torch.device('cuda'), [(1, 1, 1, 3)])
+def test_tensor_expand_as_multidim():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))
+
+
+@add_module_test(torch.float16, torch.device('cuda'), [(1, 1, 3, 3)])
+def test_tensor_expand_as_singledim_half():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))
+
+
+@add_module_test(torch.float16, torch.device('cuda'), [(1, 1, 1, 3)])
+def test_tensor_expand_as_multidim_half():
+    return ExpandAsModule(torch.randn((1, 3, 3, 3)))
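# A minimal sketch of exercising the expand_as converter introduced above,
# assuming a CUDA device and a torch2trt build that includes this diff;
# ExpandAsModule mirrors the test module from expand_as.py.
import torch
from torch2trt import torch2trt


class ExpandAsModule(torch.nn.Module):
    def __init__(self, other):
        super().__init__()
        self.other = other

    def forward(self, x):
        return x.expand_as(self.other)


# broadcast a (1, 1, 3, 3) input up to a (1, 3, 3, 3) reference tensor;
# the converter maps this to a TensorRT slice with stride (0, 1, 1)
model = ExpandAsModule(torch.randn(1, 3, 3, 3).cuda()).eval().cuda()
x = torch.randn(1, 1, 3, 3).cuda()

model_trt = torch2trt(model, [x])

# the engine output should match PyTorch, since a stride-0 slice and
# Tensor.expand_as have the same semantics
print(torch.max(torch.abs(model(x) - model_trt(x))))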