Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 35 additions & 5 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
@@ -1550,45 +1550,70 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _LazyStackedTensorDictKeysView:
keys = _LazyStackedTensorDictKeysView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
return keys

def values(self, include_nested=False, leaves_only=False, is_leaf=None):
def values(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().values(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for td in self.tensordicts:
yield from td.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def items(self, include_nested=False, leaves_only=False, is_leaf=None):
def items(
self,
include_nested=False,
leaves_only=False,
is_leaf=None,
*,
sort: bool = False,
):
if is_leaf not in (
_NESTED_TENSORS_AS_LISTS,
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
):
yield from super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)
else:
for i, td in enumerate(self.tensordicts):
for key, val in td.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
):
if isinstance(key, str):
key = (str(i), key)
@@ -3381,9 +3406,14 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def _select(
149 changes: 83 additions & 66 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
@@ -3144,12 +3144,23 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
if not include_nested and not leaves_only and is_leaf is None:
return _StringKeys(self._tensordict.keys())
if not sort:
return _StringKeys(self._tensordict.keys())
else:
return sorted(
_StringKeys(self._tensordict.keys()),
key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
)
else:
return self._nested_keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

@cache # noqa: B019
@@ -3158,12 +3169,15 @@ def _nested_keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return _TensorDictKeysView(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

# some custom methods for efficiency
@@ -3172,81 +3186,68 @@ def items(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.items()
elif include_nested and leaves_only:
if not sort:
return self._tensordict.items()
return sorted(self._tensordict.items(), key=lambda x: x[0])
elif include_nested and leaves_only and not sort:
is_leaf = _default_is_leaf if is_leaf is None else is_leaf
result = []
if is_dynamo_compiling():

def fast_iter():
for key, val in self._tensordict.items():
if not is_leaf(type(val)):
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
):
result.append(
(
(
key,
*(
(_key,)
if isinstance(_key, str)
else _key
),
),
_val,
)
)
else:
result.append((key, val))
return result

else:
# dynamo doesn't like generators
def fast_iter():
for key, val in self._tensordict.items():
if not is_leaf(type(val)):
yield from (
(
(
key,
*((_key,) if isinstance(_key, str) else _key),
),
_val,
)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
else:
yield (key, val)
def fast_iter():
for key, val in self._tensordict.items():
# We could easily make this faster, here we're iterating twice over the keys,
# but we could iterate just once.
# Ideally we should make a "dirty" list of items then call unravel_key on all of them.
if not is_leaf(type(val)):
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
):
if isinstance(_key, str):
_key = (key, _key)
else:
_key = (key, *_key)
result.append((_key, _val))
else:
result.append((key, val))
return result

return fast_iter()
else:
return super().items(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def values(
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[tuple[str, CompatibleType]]:
if not include_nested and not leaves_only:
return self._tensordict.values()
if not sort:
return self._tensordict.values()
else:
return list(zip(*sorted(self._tensordict.items(), key=lambda x: x[0])))[
1
]
else:
return TensorDictBase.values(
self,
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)


@@ -3535,9 +3536,14 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _TensorDictKeysView:
return self._source.keys(
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def entry_class(self, key: NestedKey) -> type:
@@ -4172,29 +4178,40 @@ def __init__(
include_nested: bool,
leaves_only: bool,
is_leaf: Callable[[Type], bool] = None,
sort: bool = False,
) -> None:
self.tensordict = tensordict
self.include_nested = include_nested
self.leaves_only = leaves_only
if is_leaf is None:
is_leaf = _default_is_leaf
self.is_leaf = is_leaf
self.sort = sort

def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
if not self.include_nested:
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
continue
yield key
def _iter():
if not self.include_nested:
if self.leaves_only:
for key in self._keys():
target_class = self.tensordict.entry_class(key)
if _is_tensor_collection(target_class):
continue
yield key
else:
yield from self._keys()
else:
yield from self._keys()
else:
yield from (
key if len(key) > 1 else key[0]
for key in self._iter_helper(self.tensordict)
yield from (
key if len(key) > 1 else key[0]
for key in self._iter_helper(self.tensordict)
)

if self.sort:
yield from sorted(
_iter(),
key=lambda key: ".".join(key) if isinstance(key, tuple) else key,
)
else:
yield from _iter()

def _iter_helper(
self, tensordict: T, prefix: str | None = None
177 changes: 113 additions & 64 deletions tensordict/base.py
Original file line number Diff line number Diff line change
@@ -5135,8 +5135,13 @@ def setdefault(
return self.get(key)

def items(
self, include_nested: bool = False, leaves_only: bool = False, is_leaf=None
) -> Iterator[tuple[str, CompatibleType]]:
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf=None,
*,
sort: bool = False,
) -> Iterator[tuple[str, CompatibleType]]: # noqa: D417
"""Returns a generator of key-value pairs for the tensordict.
Args:
@@ -5147,46 +5152,64 @@ def items(
is_leaf: an optional callable that indicates if a class is to be considered a
leaf or not.
Keyword Args:
sort (bool, optional): whether the keys should be sorted. For nested keys,
the keys are sorted according to their joined name (ie, ``("a", "key")`` will
be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur
significant overhead when dealing with large tensordicts.
Defaults to ``False``.
"""
if is_leaf is None:
is_leaf = _default_is_leaf

# check the conditions once only
if include_nested and leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
def _items():
if include_nested and leaves_only:
# check the conditions once only
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
)
else:
else:
yield k, val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield k, val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield k, val
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
if not is_leaf(type(val)):
yield from (
(_unravel_key_to_tuple((k, _key)), _val)
for _key, _val in val.items(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
)
)
elif leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if is_leaf(type(val)):
yield k, val
elif leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if is_leaf(type(val)):
yield k, val
else:
for k in self.keys():
yield k, self._get_str(k, NO_DEFAULT)

if sort:
yield from sorted(
_items(),
key=lambda item: (
item[0] if isinstance(item[0], str) else ".".join(item[0])
),
)
else:
for k in self.keys():
yield k, self._get_str(k, NO_DEFAULT)
yield from _items()

def non_tensor_items(self, include_nested: bool = False):
"""Returns all non-tensor leaves, maybe recursively."""
@@ -5203,7 +5226,9 @@ def values(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf=None,
) -> Iterator[CompatibleType]:
*,
sort: bool = False,
) -> Iterator[CompatibleType]: # noqa: D417
"""Returns a generator representing the values for the tensordict.
Args:
@@ -5214,39 +5239,54 @@ def values(
is_leaf: an optional callable that indicates if a class is to be considered a
leaf or not.
Keyword Args:
sort (bool, optional): whether the keys should be sorted. For nested keys,
the keys are sorted according to their joined name (ie, ``("a", "key")`` will
be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur
significant overhead when dealing with large tensordicts.
Defaults to ``False``.
"""
if is_leaf is None:
is_leaf = _default_is_leaf
# check the conditions once only
if include_nested and leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
else:
yield val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield val
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
elif leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if is_leaf(type(val)):

def _values():
# check the conditions once only
if include_nested and leaves_only:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
else:
yield val
elif include_nested:
for k in self.keys():
val = self._get_str(k, NO_DEFAULT)
yield val
if not is_leaf(type(val)):
yield from val.values(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
)
elif leaves_only:
for k in self.keys(sort=sort):
val = self._get_str(k, NO_DEFAULT)
if is_leaf(type(val)):
yield val
else:
for k in self.keys(sort=sort):
yield self._get_str(k, NO_DEFAULT)

if not sort or not include_nested:
yield from _values()
else:
for k in self.keys():
yield self._get_str(k, NO_DEFAULT)
for _, value in self.items(include_nested, leaves_only, is_leaf, sort=sort):
yield value

@cache # noqa: B019
def _values_list(
@@ -5344,7 +5384,9 @@ def keys(
self,
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] = None,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
):
"""Returns a generator of tensordict keys.
@@ -5356,6 +5398,13 @@ def keys(
is_leaf: an optional callable that indicates if a class is to be considered a
leaf or not.
Keyword Args:
sort (bool, optional): whether the keys shoulbe sorted. For nested keys,
the keys are sorted according to their joined name (ie, ``("a", "key")`` will
be counted as ``"a.key"`` for sorting). Be mindful that sorting may incur
significant overhead when dealing with large tensordicts.
Defaults to ``False``.
Examples:
>>> from tensordict import TensorDict
>>> data = TensorDict({"0": 0, "1": {"2": 2}}, batch_size=[])
8 changes: 6 additions & 2 deletions tensordict/nn/params.py
Original file line number Diff line number Diff line change
@@ -1014,10 +1014,12 @@ def values(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[CompatibleType]:
if is_leaf is None:
is_leaf = _default_is_leaf
for v in self._param_td.values(include_nested, leaves_only):
for v in self._param_td.values(include_nested, leaves_only, sort=sort):
if not is_leaf(type(v)):
yield v
continue
@@ -1082,10 +1084,12 @@ def items(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> Iterator[CompatibleType]:
if is_leaf is None:
is_leaf = _default_is_leaf
for k, v in self._param_td.items(include_nested, leaves_only):
for k, v in self._param_td.items(include_nested, leaves_only, sort=sort):
if not is_leaf(type(v)):
yield k, v
continue
3 changes: 3 additions & 0 deletions tensordict/persistent.py
Original file line number Diff line number Diff line change
@@ -471,6 +471,8 @@ def keys(
include_nested: bool = False,
leaves_only: bool = False,
is_leaf: Callable[[Type], bool] | None = None,
*,
sort: bool = False,
) -> _PersistentTDKeysView:
if is_leaf not in (None, _default_is_leaf, _is_leaf_nontensor):
raise ValueError(
@@ -481,6 +483,7 @@ def keys(
include_nested=include_nested,
leaves_only=leaves_only,
is_leaf=is_leaf,
sort=sort,
)

def _items_metadata(self, include_nested=False, leaves_only=False):
31 changes: 31 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
@@ -2239,6 +2239,37 @@ def assert_shared(td0):
td0 = td0.squeeze(0)
assert_shared(td0)

def test_sorted_keys(self):
td = TensorDict(
{
"a": {"b": 0, "c": 1},
"d": 2,
"e": {"f": 3, "g": {"h": 4, "i": 5}, "j": 6},
}
)
tdflat = td.flatten_keys()
tdflat["d"] = tdflat.pop("d")
tdflat["a.b"] = tdflat.pop("a.b")
for key1, key2 in zip(
td.keys(True, True, sort=True), tdflat.keys(True, True, sort=True)
):
if isinstance(key1, str):
assert key1 == key2
else:
assert ".".join(key1) == key2
for v1, v2 in zip(
td.values(True, True, sort=True), tdflat.values(True, True, sort=True)
):
assert v1 == v2
for (k1, v1), (k2, v2) in zip(
td.items(True, True, sort=True), tdflat.items(True, True, sort=True)
):
if isinstance(k1, str):
assert k1 == k2
else:
assert ".".join(k1) == k2
assert v1 == v2

def test_split_keys(self):
td = TensorDict(
{