Skip to content

Added converter for torch.Tensor.expand_as() method #614

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions torch2trt/converters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .div import *
from .einsum import *
from .expand import *
from .expand_as import *
from .example_plugin import *
from .floordiv import *
from .gelu import *
Expand Down
52 changes: 52 additions & 0 deletions torch2trt/converters/expand_as.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from torch2trt.torch2trt import *
from torch2trt.module_test import add_module_test


@tensorrt_converter('torch.Tensor.expand_as')
def convert_expand_as(ctx):
input = ctx.method_args[0]
output = ctx.method_return

inshape = tuple(input.shape)[1:] # exclude batch
shape = tuple(output.shape)[1:]
ndim = len(shape)
start = tuple([0]*ndim)
stride = tuple([int(i == o) for i, o in zip(inshape, shape)]) # stride == 1 if dimensions match, 0 otherwise

layer = ctx.network.add_slice(input._trt, start, shape, stride)

output._trt = layer.get_output(0)


class ExpandAsModule(torch.nn.Module):
def __init__(self, other: torch.Tensor):
super(ExpandAsModule, self).__init__()
self.other = other

def forward(self, x: torch.Tensor):
return x.expand_as(self.other)


@add_module_test(torch.float32, torch.device('cuda'), [(1),])
def test_tensor_expand_as_scalar():
return ExpandAsModule(torch.randn(3))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 1, 3, 3),])
def test_tensor_expand_as_singledim():
return ExpandAsModule(torch.randn((1, 3, 3, 3)))


@add_module_test(torch.float32, torch.device('cuda'), [(1, 1, 1, 3),])
def test_tensor_expand_as_multidim():
return ExpandAsModule(torch.randn((1, 3, 3, 3)))


@add_module_test(torch.float16, torch.device('cuda'), [(1, 1, 3, 3),])
def test_tensor_expand_as_singledim_half():
return ExpandAsModule(torch.randn((1, 3, 3, 3)))


@add_module_test(torch.float16, torch.device('cuda'), [(1, 1, 1, 3),])
def test_tensor_expand_as_multidim_half():
return ExpandAsModule(torch.randn((1, 3, 3, 3)))