
Commit 9f6b899

Author: Vincent Moens

[BugFix] Fix parsing integer batch size in AOT

ghstack-source-id: 73e7dd4
Pull Request resolved: #1004

Parent: 85b6b81

File tree: 2 files changed (+38, -9)

tensordict/_td.py (5 additions, 5 deletions)

@@ -2061,7 +2061,7 @@ def _parse_batch_size(
         source: T | dict | None,
         batch_size: Sequence[int] | torch.Size | int | None = None,
     ) -> torch.Size:
-        ERR = "batch size was not specified when creating the TensorDict instance and it could not be retrieved from source."
+        ERR = "batch size {} was not specified when creating the TensorDict instance and it could not be retrieved from source."

         if is_dynamo_compiling():
             if isinstance(batch_size, torch.Size):
@@ -2072,22 +2072,22 @@ def _parse_batch_size(
                 return torch.Size(tuple(batch_size))
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError()
+            raise ValueError(ERR.format(batch_size))

         try:
             return torch.Size(batch_size)
         except Exception:
             if batch_size is None:
                 return torch.Size([])
-            elif isinstance(batch_size, Number):
+            elif isinstance(batch_size, (Number, torch.SymInt)):
                 return torch.Size([batch_size])
             elif isinstance(source, TensorDictBase):
                 return source.batch_size
-            raise ValueError(ERR)
+            raise ValueError(ERR.format(batch_size))

     @property
     def batch_dims(self) -> int:
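Why this change is needed: under ahead-of-time export (torch.export) with dynamic shapes, a tensor dimension such as x.shape[0] is a torch.SymInt rather than a plain int. A SymInt is not a numbers.Number, so the old isinstance(batch_size, Number) branch never fired and _parse_batch_size fell through to a bare ValueError. The sketch below illustrates the root cause; ShapeProbe is a hypothetical module written for this note, not part of the commit, and it assumes PyTorch >= 2.5 export semantics.

from numbers import Number

import torch


class ShapeProbe(torch.nn.Module):  # hypothetical, for illustration only
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With dim 0 marked dynamic, x.shape[0] is symbolic during tracing...
        assert isinstance(x.shape[0], torch.SymInt)
        # ...and a SymInt is not a numbers.Number, which is why the old
        # isinstance(batch_size, Number) check rejected batch sizes taken
        # from a traced tensor's shape.
        assert not isinstance(x.shape[0], Number)
        return x


torch.export.export(
    ShapeProbe(),
    args=(torch.zeros(4, 3),),
    dynamic_shapes={"x": {0: torch.export.Dim("batch")}},
    strict=False,
)

With the fix, a SymInt batch size takes the same path as a plain integer, and the ValueError now interpolates the offending value via ERR.format(batch_size) instead of raising with no message.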

test/test_compile.py (33 additions, 4 deletions)

@@ -774,26 +774,55 @@ def call(x, td):


 @pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
+@pytest.mark.parametrize("strict", [False, True])
 class TestExport:
-    def test_export_module(self):
+    def test_export_module(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()

-    def test_export_seq(self):
+    def test_export_seq(self, strict):
         torch._dynamo.reset_code_caches()
         tdm = Seq(
             Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
             Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
         )
         x = torch.randn(3)
         y = torch.randn(3)
-        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
+        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y}, strict=strict)
         torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))

+    def test_td_output(self, strict):
+        class Test(torch.nn.Module):
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return TensorDict(
+                    {
+                        "x": x,
+                        "y": y,
+                    },
+                    batch_size=x.shape[0],
+                )
+
+        test = Test()
+        x, y = torch.zeros(2, 100), torch.zeros(2, 100)
+        result = torch.export.export(
+            test,
+            args=(x, y),
+            strict=False,
+            dynamic_shapes={
+                "x": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+                "y": {0: torch.export.Dim("batch"), 1: torch.export.Dim("time")},
+            },
+        )
+        export_mod = result.module()
+        x_new, y_new = torch.zeros(5, 100), torch.zeros(5, 100)
+        export_test = export_mod(x_new, y_new)
+        eager_test = test(x_new, y_new)
+        assert (export_test == eager_test).all()
+

 @pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
 class TestONNXExport:
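The new test_td_output builds a TensorDict whose batch_size comes from x.shape[0] under dynamic shapes, which is exactly the SymInt path fixed in _td.py, and checks that the exported module still matches eager on an unseen batch size. To exercise the parametrized export tests locally, a pytest invocation along these lines should work (assuming a source checkout with pytest installed):

pytest test/test_compile.py -k TestExport -v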
