Skip to content

Commit e8fc1e0

Browse files
justinchubypytorchmergebot
authored andcommitted
[ONNX] New export logic leveraging ExportedProgram and ONNX IR (pytorch#132530)
1/n PR to - Move code from torch-onnx from commit justinchuby/torch-onnx@395495e into torch.onnx and fixes imports. - Integrate the new export logic with the torch.onnx.export API and include basic set of tests. - Refactor the API for the change. - Improve documentation. Next PRs will be more tests and docs. Fix pytorch#129277 Pull Request resolved: pytorch#132530 Approved by: https://github.com/titaiwangms, https://github.com/malfet
1 parent 06cc2e8 commit e8fc1e0

27 files changed

+5311
-346
lines changed

docs/source/onnx_torchscript.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -715,5 +715,5 @@ Classes
715715
:template: classtemplate.rst
716716

717717
JitScalarType
718-
torch.onnx.verification.GraphInfo
719-
torch.onnx.verification.VerificationOptions
718+
verification.GraphInfo
719+
verification.VerificationOptions

mypy.ini

+9-3
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,6 @@ ignore_missing_imports = True
165165
[mypy-tensorboard.*]
166166
ignore_missing_imports = True
167167

168-
[mypy-onnx.*]
169-
ignore_missing_imports = True
170-
171168
[mypy-matplotlib.*]
172169
ignore_missing_imports = True
173170

@@ -301,5 +298,14 @@ ignore_missing_imports = True
301298
# Third party dependencies that are optional.
302299
#
303300

301+
[mypy-onnx.*]
302+
ignore_missing_imports = True
303+
304+
[mypy-onnxruntime.*]
305+
ignore_missing_imports = True
306+
307+
[mypy-onnxscript.*]
308+
ignore_missing_imports = True
309+
304310
[mypy-redis]
305311
ignore_missing_imports = True

test/onnx/dynamo/test_exporter_api.py

-217
Original file line numberDiff line numberDiff line change
@@ -163,222 +163,5 @@ def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_i
163163
)
164164

165165

166-
class TestONNXExportWithDynamo(common_utils.TestCase):
167-
def test_args_normalization_with_no_kwargs(self):
168-
exported_program = torch.export.export(
169-
SampleModelTwoInputs(),
170-
(
171-
torch.randn(1, 1, 2),
172-
torch.randn(1, 1, 2),
173-
),
174-
)
175-
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
176-
exported_program, torch.randn(1, 1, 2), torch.randn(1, 1, 2)
177-
)
178-
onnx_program_from_old_exporter = torch.onnx.export(
179-
SampleModelTwoInputs(),
180-
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
181-
dynamo=True,
182-
)
183-
self.assertEqual(
184-
onnx_program_from_new_exporter.model_proto,
185-
onnx_program_from_old_exporter.model_proto,
186-
)
187-
188-
def test_args_is_tensor_not_tuple(self):
189-
exported_program = torch.export.export(SampleModel(), (torch.randn(1, 1, 2),))
190-
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
191-
exported_program, torch.randn(1, 1, 2)
192-
)
193-
onnx_program_from_old_exporter = torch.onnx.export(
194-
SampleModel(), torch.randn(1, 1, 2), dynamo=True
195-
)
196-
self.assertEqual(
197-
onnx_program_from_new_exporter.model_proto,
198-
onnx_program_from_old_exporter.model_proto,
199-
)
200-
201-
def test_args_normalization_with_kwargs(self):
202-
exported_program = torch.export.export(
203-
SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
204-
)
205-
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
206-
exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
207-
)
208-
onnx_program_from_old_exporter = torch.onnx.export(
209-
SampleModelTwoInputs(),
210-
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
211-
dynamo=True,
212-
)
213-
self.assertEqual(
214-
onnx_program_from_new_exporter.model_proto,
215-
onnx_program_from_old_exporter.model_proto,
216-
)
217-
218-
def test_args_normalization_with_empty_dict_at_the_tail(self):
219-
exported_program = torch.export.export(
220-
SampleModelTwoInputs(), (torch.randn(1, 1, 2),), {"b": torch.randn(1, 1, 2)}
221-
)
222-
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
223-
exported_program, torch.randn(1, 1, 2), b=torch.randn(1, 1, 2)
224-
)
225-
onnx_program_from_old_exporter = torch.onnx.export(
226-
SampleModelTwoInputs(),
227-
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
228-
dynamo=True,
229-
)
230-
self.assertEqual(
231-
onnx_program_from_new_exporter.model_proto,
232-
onnx_program_from_old_exporter.model_proto,
233-
)
234-
235-
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
236-
exported_program = torch.export.export(
237-
SampleModelForDynamicShapes(),
238-
(
239-
torch.randn(2, 2, 3),
240-
torch.randn(2, 2, 3),
241-
),
242-
dynamic_shapes={
243-
"x": {
244-
0: torch.export.Dim("customx_dim_0"),
245-
1: torch.export.Dim("customx_dim_1"),
246-
2: torch.export.Dim("customx_dim_2"),
247-
},
248-
"b": {
249-
0: torch.export.Dim("customb_dim_0"),
250-
1: torch.export.Dim("customb_dim_1"),
251-
2: torch.export.Dim("customb_dim_2"),
252-
},
253-
},
254-
)
255-
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
256-
exported_program,
257-
torch.randn(2, 2, 3),
258-
b=torch.randn(2, 2, 3),
259-
)
260-
onnx_program_from_old_exporter = torch.onnx.export(
261-
SampleModelForDynamicShapes(),
262-
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
263-
dynamic_axes={
264-
"x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
265-
"b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
266-
},
267-
dynamo=True,
268-
)
269-
self.assertEqual(
270-
onnx_program_from_new_exporter.model_proto,
271-
onnx_program_from_old_exporter.model_proto,
272-
)
273-
274-
def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
275-
exported_program = torch.export.export(
276-
SampleModelForDynamicShapes(),
277-
(
278-
torch.randn(2, 2, 3),
279-
torch.randn(2, 2, 3),
280-
),
281-
dynamic_shapes={
282-
"x": {
283-
0: torch.export.Dim("customx_dim_0"),
284-
1: torch.export.Dim("customx_dim_1"),
285-
2: torch.export.Dim("customx_dim_2"),
286-
},
287-
"b": {
288-
0: torch.export.Dim("customb_dim_0"),
289-
1: torch.export.Dim("customb_dim_1"),
290-
2: torch.export.Dim("customb_dim_2"),
291-
},
292-
},
293-
)
294-
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
295-
exported_program,
296-
torch.randn(2, 2, 3),
297-
b=torch.randn(2, 2, 3),
298-
)
299-
onnx_program_from_old_exporter = torch.onnx.export(
300-
SampleModelForDynamicShapes(),
301-
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
302-
dynamic_axes={
303-
"x": [0, 1, 2],
304-
"b": [0, 1, 2],
305-
},
306-
dynamo=True,
307-
)
308-
self.assertEqual(
309-
onnx_program_from_new_exporter.model_proto,
310-
onnx_program_from_old_exporter.model_proto,
311-
)
312-
313-
def test_dynamic_axes_supports_partial_dynamic_shapes(self):
314-
exported_program = torch.export.export(
315-
SampleModelForDynamicShapes(),
316-
(
317-
torch.randn(2, 2, 3),
318-
torch.randn(2, 2, 3),
319-
),
320-
dynamic_shapes={
321-
"x": None,
322-
"b": {
323-
0: torch.export.Dim("customb_dim_0"),
324-
1: torch.export.Dim("customb_dim_1"),
325-
2: torch.export.Dim("customb_dim_2"),
326-
},
327-
},
328-
)
329-
onnx_program_from_new_exporter = torch.onnx.dynamo_export(
330-
exported_program,
331-
torch.randn(2, 2, 3),
332-
b=torch.randn(2, 2, 3),
333-
)
334-
onnx_program_from_old_exporter = torch.onnx.export(
335-
SampleModelForDynamicShapes(),
336-
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
337-
dynamic_axes={
338-
"b": [0, 1, 2],
339-
},
340-
dynamo=True,
341-
)
342-
self.assertEqual(
343-
onnx_program_from_new_exporter.model_proto,
344-
onnx_program_from_old_exporter.model_proto,
345-
)
346-
347-
def test_dynamic_shapes_hit_constraints_in_dynamo(self):
348-
# SampleModelTwoInputs has constraints becuse of add of two inputs,
349-
# so the two input shapes are related.
350-
with self.assertRaisesRegex(
351-
torch._dynamo.exc.UserError,
352-
"Constraints violated",
353-
):
354-
_ = torch.onnx.export(
355-
SampleModelTwoInputs(),
356-
(torch.randn(2, 2, 3), torch.randn(2, 2, 3)),
357-
dynamic_axes={
358-
"x": {0: "x_dim_0", 1: "x_dim_1", 2: "x_dim_2"},
359-
"b": {0: "b_dim_0", 1: "b_dim_1", 2: "b_dim_2"},
360-
},
361-
dynamo=True,
362-
)
363-
364-
def test_saved_f_exists_after_export(self):
365-
with common_utils.TemporaryFileName(suffix=".onnx") as path:
366-
_ = torch.onnx.export(
367-
SampleModel(), torch.randn(1, 1, 2), path, dynamo=True
368-
)
369-
self.assertTrue(os.path.exists(path))
370-
371-
def test_raises_error_when_input_is_script_module(self):
372-
class ScriptModule(torch.jit.ScriptModule):
373-
def forward(self, x):
374-
return x
375-
376-
with self.assertRaisesRegex(
377-
TypeError,
378-
"Dynamo export does not support ScriptModule or ScriptFunction.",
379-
):
380-
_ = torch.onnx.export(ScriptModule(), torch.randn(1, 1, 2), dynamo=True)
381-
382-
383166
if __name__ == "__main__":
384167
common_utils.run_tests()

test/onnx/exporter/README.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Directory for all ExportedProgram exporter logic.

test/onnx/exporter/test_api.py

+120
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Owner(s): ["module: onnx"]
2+
"""Simple API tests for the ONNX exporter."""
3+
4+
from __future__ import annotations
5+
6+
import os
7+
8+
import torch
9+
from torch.onnx._internal import exporter
10+
from torch.testing._internal import common_utils
11+
12+
13+
class SampleModel(torch.nn.Module):
14+
def forward(self, x):
15+
y = x + 1
16+
z = y.relu()
17+
return (y, z)
18+
19+
20+
class SampleModelTwoInputs(torch.nn.Module):
21+
def forward(self, x, b):
22+
y = x + b
23+
z = y.relu()
24+
return (y, z)
25+
26+
27+
class SampleModelForDynamicShapes(torch.nn.Module):
28+
def forward(self, x, b):
29+
return x.relu(), b.sigmoid()
30+
31+
32+
class TestExportAPIDynamo(common_utils.TestCase):
33+
"""Tests for the ONNX exporter API when dynamo=True."""
34+
35+
def test_args_normalization_with_no_kwargs(self):
36+
onnx_program = torch.onnx.export(
37+
SampleModelTwoInputs(),
38+
(torch.randn(1, 1, 2), torch.randn(1, 1, 2)),
39+
dynamo=True,
40+
)
41+
assert onnx_program
42+
exporter.verify_onnx_program(onnx_program)
43+
44+
def test_args_normalization_with_kwargs(self):
45+
onnx_program = torch.onnx.export(
46+
SampleModelTwoInputs(),
47+
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
48+
dynamo=True,
49+
)
50+
assert onnx_program
51+
exporter.verify_onnx_program(onnx_program)
52+
53+
def test_args_normalization_with_empty_dict_at_the_tail(self):
54+
onnx_program = torch.onnx.export(
55+
SampleModelTwoInputs(),
56+
(torch.randn(1, 1, 2), {"b": torch.randn(1, 1, 2)}),
57+
dynamo=True,
58+
)
59+
assert onnx_program
60+
exporter.verify_onnx_program(onnx_program)
61+
62+
def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes(self):
63+
onnx_program = torch.onnx.export(
64+
SampleModelForDynamicShapes(),
65+
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
66+
dynamic_axes={
67+
"x": {0: "customx_dim_0", 1: "customx_dim_1", 2: "customx_dim_2"},
68+
"b": {0: "customb_dim_0", 1: "customb_dim_1", 2: "customb_dim_2"},
69+
},
70+
dynamo=True,
71+
)
72+
assert onnx_program
73+
exporter.verify_onnx_program(onnx_program)
74+
75+
def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names(self):
76+
onnx_program = torch.onnx.export(
77+
SampleModelForDynamicShapes(),
78+
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
79+
dynamic_axes={
80+
"x": [0, 1, 2],
81+
"b": [0, 1, 2],
82+
},
83+
dynamo=True,
84+
)
85+
assert onnx_program
86+
exporter.verify_onnx_program(onnx_program)
87+
88+
def test_dynamic_axes_supports_partial_dynamic_shapes(self):
89+
onnx_program = torch.onnx.export(
90+
SampleModelForDynamicShapes(),
91+
(torch.randn(2, 2, 3), {"b": torch.randn(2, 2, 3)}),
92+
dynamic_axes={
93+
"b": [0, 1, 2],
94+
},
95+
dynamo=True,
96+
)
97+
assert onnx_program
98+
exporter.verify_onnx_program(onnx_program)
99+
100+
def test_saved_f_exists_after_export(self):
101+
with common_utils.TemporaryFileName(suffix=".onnx") as path:
102+
_ = torch.onnx.export(
103+
SampleModel(), (torch.randn(1, 1, 2),), path, dynamo=True
104+
)
105+
self.assertTrue(os.path.exists(path))
106+
107+
def test_export_supports_script_module(self):
108+
class ScriptModule(torch.nn.Module):
109+
def forward(self, x):
110+
return x
111+
112+
onnx_program = torch.onnx.export(
113+
torch.jit.script(ScriptModule()), (torch.randn(1, 1, 2),), dynamo=True
114+
)
115+
assert onnx_program
116+
exporter.verify_onnx_program(onnx_program)
117+
118+
119+
if __name__ == "__main__":
120+
common_utils.run_tests()

0 commit comments

Comments
 (0)