Skip to content

Commit 1b6f451

Browse files
committed
[Feature] torch.export and onnx compatibility
ghstack-source-id: d312fc1dee177275a73482210c1ecfbe73b04f9e Pull Request resolved: #991
1 parent 436d6e8 commit 1b6f451

File tree

5 files changed

+208
-15
lines changed

5 files changed

+208
-15
lines changed

tensordict/_td.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1987,7 +1987,7 @@ def from_dict(
19871987
)
19881988

19891989
batch_size_set = torch.Size(()) if batch_size is None else batch_size
1990-
input_dict = copy(input_dict)
1990+
input_dict = dict(input_dict)
19911991
for key, value in list(input_dict.items()):
19921992
if isinstance(value, (dict,)):
19931993
# we don't know if another tensor of smaller size is coming

tensordict/base.py

+26-13
Original file line numberDiff line numberDiff line change
@@ -390,13 +390,7 @@ def __getitem__(self, index: IndexType) -> Any:
390390
# _unravel_key_to_tuple will return an empty tuple if the index isn't a NestedKey
391391
idx_unravel = _unravel_key_to_tuple(index)
392392
if idx_unravel:
393-
result = self._get_tuple(idx_unravel, NO_DEFAULT)
394-
if is_non_tensor(result):
395-
result_data = getattr(result, "data", NO_DEFAULT)
396-
if result_data is NO_DEFAULT:
397-
return result.tolist()
398-
return result_data
399-
return result
393+
return self._get_tuple_maybe_non_tensor(idx_unravel, NO_DEFAULT)
400394

401395
if (istuple and not index) or (not istuple and index is Ellipsis):
402396
# empty tuple returns self
@@ -4669,6 +4663,15 @@ def _get_str(self, key, default): ...
46694663
@abc.abstractmethod
46704664
def _get_tuple(self, key, default): ...
46714665

4666+
def _get_tuple_maybe_non_tensor(self, key, default):
4667+
result = self._get_tuple(key, default)
4668+
if is_non_tensor(result):
4669+
result_data = getattr(result, "data", NO_DEFAULT)
4670+
if result_data is NO_DEFAULT:
4671+
return result.tolist()
4672+
return result_data
4673+
return result
4674+
46724675
def get_at(
46734676
self, key: NestedKey, index: IndexType, default: CompatibleType = NO_DEFAULT
46744677
) -> CompatibleType:
@@ -8549,25 +8552,34 @@ def _remove_batch_dim(self, vmap_level, batch_size, out_dim): ...
85498552
def _maybe_remove_batch_dim(self, funcname, vmap_level, batch_size, out_dim): ...
85508553

85518554
# Validation and checks
8552-
def _convert_to_tensor(self, array: np.ndarray) -> Tensor:
8555+
def _convert_to_tensor(
8556+
self, array: Any
8557+
) -> Tensor | "NonTensorData" | TensorDictBase: # noqa: F821
8558+
# We are sure that array is not a dict or anything in _ACCEPTED_CLASSES
8559+
castable = None
85538560
if isinstance(array, (float, int, bool)):
8554-
pass
8561+
castable = True
85558562
elif isinstance(array, np.ndarray) and array.dtype.names is not None:
85568563
return TensorDictBase.from_struct_array(array, device=self.device)
8564+
elif isinstance(array, np.ndarray):
8565+
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
85578566
elif isinstance(array, np.bool_):
8567+
castable = True
85588568
array = array.item()
8559-
elif isinstance(array, list):
8569+
elif isinstance(array, (list, tuple)):
85608570
array = np.asarray(array)
8571+
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
85618572
elif hasattr(array, "numpy"):
85628573
# tf.Tensor with no shape can't be converted otherwise
85638574
array = array.numpy()
8564-
try:
8575+
castable = array.dtype.kind in ("c", "i", "f", "b", "u")
8576+
if castable:
85658577
return torch.as_tensor(array, device=self.device)
8566-
except Exception:
8578+
else:
85678579
from tensordict.tensorclass import NonTensorData
85688580

85698581
return NonTensorData(
8570-
array,
8582+
data=array,
85718583
batch_size=self.batch_size,
85728584
device=self.device,
85738585
names=self._maybe_names(),
@@ -8624,6 +8636,7 @@ def _validate_value(
86248636
)
86258637
is_tc = True
86268638
elif not issubclass(cls, _ACCEPTED_CLASSES):
8639+
# If cls is not a tensor
86278640
try:
86288641
value = self._convert_to_tensor(value)
86298642
except ValueError as err:

tensordict/nn/common.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,53 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
324324
return func(_self, tensordict, *args, **kwargs)
325325
return func(tensordict, *args, **kwargs)
326326

327+
return self._update_func_signature(func, wrapper)
328+
329+
def _update_func_signature(self, func, wrapper):
330+
# Create a new signature with the desired parameters
331+
# Get the original function's signature
332+
orig_signature = inspect.signature(func)
333+
334+
# params = [inspect.Parameter(name='', kind=inspect.Parameter.VAR_POSITIONAL)]
335+
params = []
336+
i = -1
337+
for i, param in enumerate(orig_signature.parameters.values()):
338+
if param.kind in (
339+
inspect.Parameter.VAR_KEYWORD,
340+
inspect.Parameter.KEYWORD_ONLY,
341+
):
342+
i = i - 1
343+
break
344+
if param.default is inspect._empty:
345+
params.append(
346+
inspect.Parameter(
347+
name=param.name,
348+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
349+
default=None,
350+
)
351+
)
352+
else:
353+
params.append(param)
354+
355+
# Add the **kwargs parameter
356+
357+
# for key in self.get_source(func, self_func):
358+
if i >= 0:
359+
params.extend(list(orig_signature.parameters.values())[i + 1 :])
360+
elif i == -1:
361+
params.extend(list(orig_signature.parameters.values()))
362+
363+
# Update the wrapper's signature
364+
wrapper.__signature__ = inspect.Signature(params)
365+
327366
return wrapper
328367

368+
def get_source(self, func, self_func):
369+
source = self.source
370+
if isinstance(source, str):
371+
return getattr(self_func, source)
372+
return source
373+
329374

330375
class _OutKeysSelect:
331376
def __init__(self, out_keys):
@@ -1226,7 +1271,12 @@ def forward(
12261271
tensors = ()
12271272
else:
12281273
# TODO: v0.7: remove the None
1229-
tensors = tuple(tensordict.get(in_key, None) for in_key in self.in_keys)
1274+
tensors = tuple(
1275+
tensordict._get_tuple_maybe_non_tensor(
1276+
_unravel_key_to_tuple(in_key), None
1277+
)
1278+
for in_key in self.in_keys
1279+
)
12301280
try:
12311281
tensors = self._call_module(tensors, **kwargs)
12321282
except Exception as err:

tensordict/tensorclass.py

+1
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,7 @@ def __subclasscheck__(self, subclass):
135135
"_get_names_idx", # no wrap output
136136
"_get_str",
137137
"_get_tuple",
138+
"_get_tuple_maybe_non_tensor",
138139
"_has_names",
139140
"_items_list",
140141
"_maybe_names",

test/test_compile.py

+129
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
# LICENSE file in the root directory of this source tree.
55
import argparse
66
import contextlib
7+
import importlib.util
78
import os
9+
from pathlib import Path
810
from typing import Any
911

1012
import pytest
@@ -14,9 +16,14 @@
1416

1517
from tensordict import assert_close, tensorclass, TensorDict, TensorDictParams
1618
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
19+
from torch.utils._pytree import tree_map
1720

1821
TORCH_VERSION = version.parse(torch.__version__).base_version
1922

23+
_has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
24+
25+
_v2_5 = version.parse(".".join(TORCH_VERSION.split(".")[:3])) >= version.parse("2.5.0")
26+
2027

2128
def test_vmap_compile():
2229
# Since we monkey patch vmap we need to make sure compile is happy with it
@@ -605,6 +612,33 @@ def remove_hidden(td):
605612
assert_close(module(td), module_compile(td))
606613
assert module_compile(td) is not td
607614

615+
def test_dispatch_nontensor(self, mode):
616+
torch._dynamo.reset_code_caches()
617+
618+
# Non tensor
619+
x = torch.randn(3)
620+
y = None
621+
mod = Seq(
622+
Mod(lambda x, y: x[y, :], in_keys=["x", "y"], out_keys=["_z"]),
623+
Mod(lambda x, z: z * x, in_keys=["x", "_z"], out_keys=["out"]),
624+
)
625+
assert mod(x=x, y=y)[-1].shape == torch.Size((1, 3))
626+
mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
627+
torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))
628+
629+
def test_dispatch_tensor(self, mode):
630+
torch._dynamo.reset_code_caches()
631+
632+
x = torch.randn(3)
633+
y = torch.randn(3)
634+
mod = Seq(
635+
Mod(lambda x, y: x + y, in_keys=["x", "y"], out_keys=["z"]),
636+
Mod(lambda x, z: z * x, in_keys=["x", "z"], out_keys=["out"]),
637+
)
638+
mod(x=x, y=y)
639+
mod_compile = torch.compile(mod, fullgraph=_v2_5, mode=mode)
640+
torch.testing.assert_close(mod(x=x, y=y), mod_compile(x=x, y=y))
641+
608642

609643
@pytest.mark.skipif(not (TORCH_VERSION > "2.4.0"), reason="requires torch>2.4")
610644
@pytest.mark.parametrize("mode", [None, "reduce-overhead"])
@@ -737,6 +771,101 @@ def call(x, td):
737771
assert (td_zero == 0).all()
738772

739773

774+
@pytest.mark.skipif(not _v2_5, reason="Requires PT>=2.5")
775+
class TestExport:
776+
def test_export_module(self):
777+
torch._dynamo.reset_code_caches()
778+
tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
779+
x = torch.randn(3)
780+
y = torch.randn(3)
781+
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
782+
assert (out.module()(x=x, y=y) == tdm(x=x, y=y)).all()
783+
784+
def test_export_seq(self):
785+
torch._dynamo.reset_code_caches()
786+
tdm = Seq(
787+
Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
788+
Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
789+
)
790+
x = torch.randn(3)
791+
y = torch.randn(3)
792+
out = torch.export.export(tdm, args=(), kwargs={"x": x, "y": y})
793+
torch.testing.assert_close(out.module()(x=x, y=y), tdm(x=x, y=y))
794+
795+
796+
@pytest.mark.skipif(not _has_onnx, reason="ONNX is not available")
797+
class TestONNXExport:
798+
def test_onnx_export_module(self, tmpdir):
799+
tdm = Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"])
800+
x = torch.randn(3)
801+
y = torch.randn(3)
802+
torch_input = {"x": x, "y": y}
803+
onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
804+
805+
onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
806+
807+
path = Path(tmpdir) / "file.onnx"
808+
onnx_program.save(str(path))
809+
import onnxruntime
810+
811+
ort_session = onnxruntime.InferenceSession(
812+
path, providers=["CPUExecutionProvider"]
813+
)
814+
815+
def to_numpy(tensor):
816+
return (
817+
tensor.detach().cpu().numpy()
818+
if tensor.requires_grad
819+
else tensor.cpu().numpy()
820+
)
821+
822+
onnxruntime_input = {
823+
k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
824+
}
825+
826+
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
827+
torch.testing.assert_close(
828+
torch.as_tensor(onnxruntime_outputs[0]), tdm(x=x, y=y)
829+
)
830+
831+
def test_onnx_export_seq(self, tmpdir):
832+
tdm = Seq(
833+
Mod(lambda x, y: x * y, in_keys=["x", "y"], out_keys=["z"]),
834+
Mod(lambda z, x: z + x, in_keys=["z", "x"], out_keys=["out"]),
835+
)
836+
x = torch.randn(3)
837+
y = torch.randn(3)
838+
torch_input = {"x": x, "y": y}
839+
torch.onnx.dynamo_export(tdm, x=x, y=y)
840+
onnx_program = torch.onnx.dynamo_export(tdm, **torch_input)
841+
842+
onnx_input = onnx_program.adapt_torch_inputs_to_onnx(**torch_input)
843+
844+
path = Path(tmpdir) / "file.onnx"
845+
onnx_program.save(str(path))
846+
import onnxruntime
847+
848+
ort_session = onnxruntime.InferenceSession(
849+
path, providers=["CPUExecutionProvider"]
850+
)
851+
852+
def to_numpy(tensor):
853+
return (
854+
tensor.detach().cpu().numpy()
855+
if tensor.requires_grad
856+
else tensor.cpu().numpy()
857+
)
858+
859+
onnxruntime_input = {
860+
k.name: to_numpy(v) for k, v in zip(ort_session.get_inputs(), onnx_input)
861+
}
862+
863+
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)
864+
torch.testing.assert_close(
865+
tree_map(torch.as_tensor, onnxruntime_outputs), tdm(x=x, y=y)
866+
)
867+
868+
740869
if __name__ == "__main__":
741870
args, unknown = argparse.ArgumentParser().parse_known_args()
742871
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

0 commit comments

Comments
 (0)