4 | 4 | # LICENSE file in the root directory of this source tree.
5 | 5 | import argparse
6 | 6 | import contextlib
| 7 | +import importlib.util
7 | 8 | import os
| 9 | +from pathlib import Path
8 | 10 | from typing import Any
9 | 11 |
10 | 12 | import pytest
14 | 16 |
15 | 17 | from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams
16 | 18 | from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
| 19 | +from torch.utils._pytree import tree_map
17 | 20 |
18 | 21 | TORCH_VERSION = version.parse(torch.__version__).base_version
19 | 22 |
| 23 | +_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
| 24 | +
| 25 | +_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")
| 26 | +
20 | 27 |
21 | 28 | def test_vmap_compile():
22 | 29 |     # Since we monkey-patch vmap, we need to make sure compile is happy with it
@@ -605,6 +612,33 @@ def remove_hidden(td):
605 | 612 |         assert_close(module(td), module_compile(td))
606 | 613 |         assert module_compile(td) is not td
607 | 614 |
| 615 | +    def test_dispatch_nontensor(self, mode):
| 616 | +        torch._dynamo.reset_code_caches()
| 617 | +
| 618 | +        # Non-tensor input: y is None and is used as an index into x
| 619 | +        x = torch.randn(3)
| 620 | +        y = None
| 621 | +        mod = Seq(
| 622 | +            Mod(lambda x, y: x[y, :], in_keys=["x", "y"], out_keys=["_z"]),
| 623 | +            Mod(lambda x, z: z * x, in_keys=["x", "_z"], out_keys=["out"]),
| 624 | +        )
| 625 | +        assert mod(x=x, y=y)[-1].shape == torch.Size((1, 3))
| 626 | +        mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)  # full-graph capture is only requested on torch>=2.5
| 627 | +        torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))
| 628 | +
| 629 | +    def test_dispatch_tensor(self, mode):
| 630 | +        torch._dynamo.reset_code_caches()
| 631 | +
| 632 | +        x = torch.randn(3)
| 633 | +        y = torch.randn(3)
| 634 | +        mod = Seq(
| 635 | +            Mod(lambda x, y: x + y, in_keys=["x", "y"], out_keys=["z"]),
| 636 | +            Mod(lambda x, z: z * x, in_keys=["x", "z"], out_keys=["out"]),
| 637 | +        )
| 638 | +        mod(x=x, y=y)
| 639 | +        mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
| 640 | +        torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))
| 641 | +
608 | 642 |
609 | 643 | @pytest.mark.skipif(not (TORCH_VERSION > "2.4.0"), reason="requires torch>2.4")
610 | 644 | @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
@@ -737,6 +771,101 @@ def call(x, td):
737 | 771 |         assert (td_zero == 0).all()
738 | 772 |
739 | 773 |
| 774 | +@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
| 775 | +class TestExport:
| 776 | +    def test_export_module(self):
| 777 | +        torch._dynamo.reset_code_caches()
| 778 | +        tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
| 779 | +        x = torch.randn(3)
| 780 | +        y = torch.randn(3)
| 781 | +        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
| 782 | +        assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()
| 783 | +
| 784 | +    def test_export_seq(self):
| 785 | +        torch._dynamo.reset_code_caches()
| 786 | +        tdm = Seq(
| 787 | +            Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
| 788 | +            Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
| 789 | +        )
| 790 | +        x = torch.randn(3)
| 791 | +        y = torch.randn(3)
| 792 | +        out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
| 793 | +        torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
| 794 | +
| 795 | +
| 796 | +@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
| 797 | +class TestONNXExport:
| 798 | +    def test_onnx_export_module(self, tmpdir):
| 799 | +        tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
| 800 | +        x = torch.randn(3)
| 801 | +        y = torch.randn(3)
| 802 | +        torch_input = {"x": x, "y": y}
| 803 | +        onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
| 804 | +
| 805 | +        onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
| 806 | +
| 807 | +        path = Path(tmpdir) / "file.onnx"
| 808 | +        onnx_program.save(str(path))
| 809 | +        import onnxruntime
| 810 | +
| 811 | +        ort_session = onnxruntime.InferenceSession(
| 812 | +            path, providers=["CPUExecutionProvider"]
| 813 | +        )
| 814 | +
| 815 | +        def to_numpy(tensor):
| 816 | +            return (
| 817 | +                tensor.detach().cpu().numpy()
| 818 | +                if tensor.requires_grad
| 819 | +                else tensor.cpu().numpy()
| 820 | +            )
| 821 | +
| 822 | +        onnxruntime_input = {
| 823 | +            k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
| 824 | +        }
| 825 | +
| 826 | +        onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
| 827 | +        torch.testing.assert_close(
| 828 | +            torch.as_tensor(onnxruntime_outputs[0]), tdm(x=x, y=y)
| 829 | +        )
| 830 | +
| 831 | +    def test_onnx_export_seq(self, tmpdir):
| 832 | +        tdm = Seq(
| 833 | +            Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
| 834 | +            Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
| 835 | +        )
| 836 | +        x = torch.randn(3)
| 837 | +        y = torch.randn(3)
| 838 | +        torch_input = {"x": x, "y": y}
| 839 | +        torch.onnx.dynamo_export(tdm, x=x, y=y)
| 840 | +        onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
| 841 | +
| 842 | +        onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
| 843 | +
| 844 | +        path = Path(tmpdir) / "file.onnx"
| 845 | +        onnx_program.save(str(path))
| 846 | +        import onnxruntime
| 847 | +
| 848 | +        ort_session = onnxruntime.InferenceSession(
| 849 | +            path, providers=["CPUExecutionProvider"]
| 850 | +        )
| 851 | +
| 852 | +        def to_numpy(tensor):
| 853 | +            return (
| 854 | +                tensor.detach().cpu().numpy()
| 855 | +                if tensor.requires_grad
| 856 | +                else tensor.cpu().numpy()
| 857 | +            )
| 858 | +
| 859 | +        onnxruntime_input = {
| 860 | +            k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
| 861 | +        }
| 862 | +
| 863 | +        onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
| 864 | +        torch.testing.assert_close(
| 865 | +            tree_map(torch.as_tensor, onnxruntime_outputs), tdm(x=x, y=y)
| 866 | +        )
| 867 | +
| 868 | +
740 | 869 | if __name__ == "__main__":
741 | 870 |     args, unknown = argparse.ArgumentParser().parse_known_args()
742 | 871 |     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
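
As a standalone illustration of the torch.export path the new TestExport tests exercise (a minimal sketch, assuming torch>=2.5 and an installed tensordict; variable names are illustrative):

    # Minimal sketch of the pattern covered by TestExport.test_export_seq.
    # Assumes torch>=2.5 and an installed tensordict; names are illustrative.
    import torch
    from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

    seq = Seq(
        Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
        Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
    )
    x, y = torch.randn(3), torch.randn(3)

    # Capture the module as an ExportedProgram, passing the inputs as kwargs.
    exported = torch.export.export(seq, args=(), kwargs={"x": x, "y": y})

    # The exported graph module accepts the same keyword arguments as the original.
    torch.testing.assert_close(exported.module()(x=x, y=y), seq(x=x, y=y))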