@@ -774,26 +774,55 @@ def call(x, td):
774
774
775
775
776
776
@pytest .mark .skipif (not _v2_5 , reason = "Requires PT>=2.5" )
777
+ @pytest .mark .parametrize ("strict" , [False , True ])
777
778
class TestExport :
778
- def test_export_module (self ):
779
+ def test_export_module (self , strict ):
779
780
torch ._dynamo .reset_code_caches ()
780
781
tdm = Mod (lambda x , y : x * y , in_keys = ["x" , "y" ], out_keys = ["z" ])
781
782
x = torch .randn (3 )
782
783
y = torch .randn (3 )
783
- out = torch .export .export (tdm , args = (), kwargs = {"x" : x , "y" : y })
784
+ out = torch .export .export (tdm , args = (), kwargs = {"x" : x , "y" : y }, strict = strict )
784
785
assert (out .module ()(x = x , y = y ) == tdm (x = x , y = y )).all ()
785
786
786
- def test_export_seq (self ):
787
+ def test_export_seq (self , strict ):
787
788
torch ._dynamo .reset_code_caches ()
788
789
tdm = Seq (
789
790
Mod (lambda x , y : x * y , in_keys = ["x" , "y" ], out_keys = ["z" ]),
790
791
Mod (lambda z , x : z + x , in_keys = ["z" , "x" ], out_keys = ["out" ]),
791
792
)
792
793
x = torch .randn (3 )
793
794
y = torch .randn (3 )
794
- out = torch .export .export (tdm , args = (), kwargs = {"x" : x , "y" : y })
795
+ out = torch .export .export (tdm , args = (), kwargs = {"x" : x , "y" : y }, strict = strict )
795
796
torch .testing .assert_close (out .module ()(x = x , y = y ), tdm (x = x , y = y ))
796
797
798
+ def test_td_output (self , strict ):
799
+ class Test (torch .nn .Module ):
800
+ def forward (self , x : torch .Tensor , y : torch .Tensor ):
801
+ return TensorDict (
802
+ {
803
+ "x" : x ,
804
+ "y" : y ,
805
+ },
806
+ batch_size = x .shape [0 ],
807
+ )
808
+
809
+ test = Test ()
810
+ x , y = torch .zeros (2 , 100 ), torch .zeros (2 , 100 )
811
+ result = torch .export .export (
812
+ test ,
813
+ args = (x , y ),
814
+ strict = False ,
815
+ dynamic_shapes = {
816
+ "x" : {0 : torch .export .Dim ("batch" ), 1 : torch .export .Dim ("time" )},
817
+ "y" : {0 : torch .export .Dim ("batch" ), 1 : torch .export .Dim ("time" )},
818
+ },
819
+ )
820
+ export_mod = result .module ()
821
+ x_new , y_new = torch .zeros (5 , 100 ), torch .zeros (5 , 100 )
822
+ export_test = export_mod (x_new , y_new )
823
+ eager_test = test (x_new , y_new )
824
+ assert (export_test == eager_test ).all ()
825
+
797
826
798
827
@pytest .mark .skipif (not _has_onnx , reason = "ONNX is not available" )
799
828
class TestONNXExport :
0 commit comments