@@ -163,222 +163,5 @@ def test_raise_from_diagnostic_warning_when_diagnostic_option_warning_as_error_i
163
163
)
164
164
165
165
166
- class TestONNXExportWithDynamo (common_utils .TestCase ):
167
- def test_args_normalization_with_no_kwargs (self ):
168
- exported_program = torch .export .export (
169
- SampleModelTwoInputs (),
170
- (
171
- torch .randn (1 , 1 , 2 ),
172
- torch .randn (1 , 1 , 2 ),
173
- ),
174
- )
175
- onnx_program_from_new_exporter = torch .onnx .dynamo_export (
176
- exported_program , torch .randn (1 , 1 , 2 ), torch .randn (1 , 1 , 2 )
177
- )
178
- onnx_program_from_old_exporter = torch .onnx .export (
179
- SampleModelTwoInputs (),
180
- (torch .randn (1 , 1 , 2 ), torch .randn (1 , 1 , 2 )),
181
- dynamo = True ,
182
- )
183
- self .assertEqual (
184
- onnx_program_from_new_exporter .model_proto ,
185
- onnx_program_from_old_exporter .model_proto ,
186
- )
187
-
188
- def test_args_is_tensor_not_tuple (self ):
189
- exported_program = torch .export .export (SampleModel (), (torch .randn (1 , 1 , 2 ),))
190
- onnx_program_from_new_exporter = torch .onnx .dynamo_export (
191
- exported_program , torch .randn (1 , 1 , 2 )
192
- )
193
- onnx_program_from_old_exporter = torch .onnx .export (
194
- SampleModel (), torch .randn (1 , 1 , 2 ), dynamo = True
195
- )
196
- self .assertEqual (
197
- onnx_program_from_new_exporter .model_proto ,
198
- onnx_program_from_old_exporter .model_proto ,
199
- )
200
-
201
- def test_args_normalization_with_kwargs (self ):
202
- exported_program = torch .export .export (
203
- SampleModelTwoInputs (), (torch .randn (1 , 1 , 2 ),), {"b" : torch .randn (1 , 1 , 2 )}
204
- )
205
- onnx_program_from_new_exporter = torch .onnx .dynamo_export (
206
- exported_program , torch .randn (1 , 1 , 2 ), b = torch .randn (1 , 1 , 2 )
207
- )
208
- onnx_program_from_old_exporter = torch .onnx .export (
209
- SampleModelTwoInputs (),
210
- (torch .randn (1 , 1 , 2 ), {"b" : torch .randn (1 , 1 , 2 )}),
211
- dynamo = True ,
212
- )
213
- self .assertEqual (
214
- onnx_program_from_new_exporter .model_proto ,
215
- onnx_program_from_old_exporter .model_proto ,
216
- )
217
-
218
- def test_args_normalization_with_empty_dict_at_the_tail (self ):
219
- exported_program = torch .export .export (
220
- SampleModelTwoInputs (), (torch .randn (1 , 1 , 2 ),), {"b" : torch .randn (1 , 1 , 2 )}
221
- )
222
- onnx_program_from_new_exporter = torch .onnx .dynamo_export (
223
- exported_program , torch .randn (1 , 1 , 2 ), b = torch .randn (1 , 1 , 2 )
224
- )
225
- onnx_program_from_old_exporter = torch .onnx .export (
226
- SampleModelTwoInputs (),
227
- (torch .randn (1 , 1 , 2 ), {"b" : torch .randn (1 , 1 , 2 )}),
228
- dynamo = True ,
229
- )
230
- self .assertEqual (
231
- onnx_program_from_new_exporter .model_proto ,
232
- onnx_program_from_old_exporter .model_proto ,
233
- )
234
-
235
- def test_dynamic_axes_enable_dynamic_shapes_with_fully_specified_axes (self ):
236
- exported_program = torch .export .export (
237
- SampleModelForDynamicShapes (),
238
- (
239
- torch .randn (2 , 2 , 3 ),
240
- torch .randn (2 , 2 , 3 ),
241
- ),
242
- dynamic_shapes = {
243
- "x" : {
244
- 0 : torch .export .Dim ("customx_dim_0" ),
245
- 1 : torch .export .Dim ("customx_dim_1" ),
246
- 2 : torch .export .Dim ("customx_dim_2" ),
247
- },
248
- "b" : {
249
- 0 : torch .export .Dim ("customb_dim_0" ),
250
- 1 : torch .export .Dim ("customb_dim_1" ),
251
- 2 : torch .export .Dim ("customb_dim_2" ),
252
- },
253
- },
254
- )
255
- onnx_program_from_new_exporter = torch .onnx .dynamo_export (
256
- exported_program ,
257
- torch .randn (2 , 2 , 3 ),
258
- b = torch .randn (2 , 2 , 3 ),
259
- )
260
- onnx_program_from_old_exporter = torch .onnx .export (
261
- SampleModelForDynamicShapes (),
262
- (torch .randn (2 , 2 , 3 ), {"b" : torch .randn (2 , 2 , 3 )}),
263
- dynamic_axes = {
264
- "x" : {0 : "customx_dim_0" , 1 : "customx_dim_1" , 2 : "customx_dim_2" },
265
- "b" : {0 : "customb_dim_0" , 1 : "customb_dim_1" , 2 : "customb_dim_2" },
266
- },
267
- dynamo = True ,
268
- )
269
- self .assertEqual (
270
- onnx_program_from_new_exporter .model_proto ,
271
- onnx_program_from_old_exporter .model_proto ,
272
- )
273
-
274
- def test_dynamic_axes_enable_dynamic_shapes_with_default_axe_names (self ):
275
- exported_program = torch .export .export (
276
- SampleModelForDynamicShapes (),
277
- (
278
- torch .randn (2 , 2 , 3 ),
279
- torch .randn (2 , 2 , 3 ),
280
- ),
281
- dynamic_shapes = {
282
- "x" : {
283
- 0 : torch .export .Dim ("customx_dim_0" ),
284
- 1 : torch .export .Dim ("customx_dim_1" ),
285
- 2 : torch .export .Dim ("customx_dim_2" ),
286
- },
287
- "b" : {
288
- 0 : torch .export .Dim ("customb_dim_0" ),
289
- 1 : torch .export .Dim ("customb_dim_1" ),
290
- 2 : torch .export .Dim ("customb_dim_2" ),
291
- },
292
- },
293
- )
294
- onnx_program_from_new_exporter = torch .onnx .dynamo_export (
295
- exported_program ,
296
- torch .randn (2 , 2 , 3 ),
297
- b = torch .randn (2 , 2 , 3 ),
298
- )
299
- onnx_program_from_old_exporter = torch .onnx .export (
300
- SampleModelForDynamicShapes (),
301
- (torch .randn (2 , 2 , 3 ), {"b" : torch .randn (2 , 2 , 3 )}),
302
- dynamic_axes = {
303
- "x" : [0 , 1 , 2 ],
304
- "b" : [0 , 1 , 2 ],
305
- },
306
- dynamo = True ,
307
- )
308
- self .assertEqual (
309
- onnx_program_from_new_exporter .model_proto ,
310
- onnx_program_from_old_exporter .model_proto ,
311
- )
312
-
313
- def test_dynamic_axes_supports_partial_dynamic_shapes (self ):
314
- exported_program = torch .export .export (
315
- SampleModelForDynamicShapes (),
316
- (
317
- torch .randn (2 , 2 , 3 ),
318
- torch .randn (2 , 2 , 3 ),
319
- ),
320
- dynamic_shapes = {
321
- "x" : None ,
322
- "b" : {
323
- 0 : torch .export .Dim ("customb_dim_0" ),
324
- 1 : torch .export .Dim ("customb_dim_1" ),
325
- 2 : torch .export .Dim ("customb_dim_2" ),
326
- },
327
- },
328
- )
329
- onnx_program_from_new_exporter = torch .onnx .dynamo_export (
330
- exported_program ,
331
- torch .randn (2 , 2 , 3 ),
332
- b = torch .randn (2 , 2 , 3 ),
333
- )
334
- onnx_program_from_old_exporter = torch .onnx .export (
335
- SampleModelForDynamicShapes (),
336
- (torch .randn (2 , 2 , 3 ), {"b" : torch .randn (2 , 2 , 3 )}),
337
- dynamic_axes = {
338
- "b" : [0 , 1 , 2 ],
339
- },
340
- dynamo = True ,
341
- )
342
- self .assertEqual (
343
- onnx_program_from_new_exporter .model_proto ,
344
- onnx_program_from_old_exporter .model_proto ,
345
- )
346
-
347
- def test_dynamic_shapes_hit_constraints_in_dynamo (self ):
348
- # SampleModelTwoInputs has constraints becuse of add of two inputs,
349
- # so the two input shapes are related.
350
- with self .assertRaisesRegex (
351
- torch ._dynamo .exc .UserError ,
352
- "Constraints violated" ,
353
- ):
354
- _ = torch .onnx .export (
355
- SampleModelTwoInputs (),
356
- (torch .randn (2 , 2 , 3 ), torch .randn (2 , 2 , 3 )),
357
- dynamic_axes = {
358
- "x" : {0 : "x_dim_0" , 1 : "x_dim_1" , 2 : "x_dim_2" },
359
- "b" : {0 : "b_dim_0" , 1 : "b_dim_1" , 2 : "b_dim_2" },
360
- },
361
- dynamo = True ,
362
- )
363
-
364
- def test_saved_f_exists_after_export (self ):
365
- with common_utils .TemporaryFileName (suffix = ".onnx" ) as path :
366
- _ = torch .onnx .export (
367
- SampleModel (), torch .randn (1 , 1 , 2 ), path , dynamo = True
368
- )
369
- self .assertTrue (os .path .exists (path ))
370
-
371
- def test_raises_error_when_input_is_script_module (self ):
372
- class ScriptModule (torch .jit .ScriptModule ):
373
- def forward (self , x ):
374
- return x
375
-
376
- with self .assertRaisesRegex (
377
- TypeError ,
378
- "Dynamo export does not support ScriptModule or ScriptFunction." ,
379
- ):
380
- _ = torch .onnx .export (ScriptModule (), torch .randn (1 , 1 , 2 ), dynamo = True )
381
-
382
-
383
166
if __name__ == "__main__" :
384
167
common_utils .run_tests ()
0 commit comments