Skip to content

Commit 7422834

Browse files
author
Vincent Moens
committed
[Feature] sorted keys, values and items
ghstack-source-id: 624542b Pull Request resolved: #965
1 parent bf6f30e commit 7422834

File tree

6 files changed

+261
-137
lines changed

6 files changed

+261
-137
lines changed

tensordict/_lazy.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1507,45 +1507,70 @@ def keys(
15071507
include_nested: bool = False,
15081508
leaves_only: bool = False,
15091509
is_leaf: Callable[[Type], bool] | None = None,
1510+
*,
1511+
sort: bool = False,
15101512
) -> _LazyStackedTensorDictKeysView:
15111513
keys = _LazyStackedTensorDictKeysView(
15121514
self,
15131515
include_nested=include_nested,
15141516
leaves_only=leaves_only,
15151517
is_leaf=is_leaf,
1518+
sort=sort,
15161519
)
15171520
return keys
15181521

1519-
def values(
    self,
    include_nested=False,
    leaves_only=False,
    is_leaf=None,
    *,
    sort: bool = False,
):
    """Yield the values held by this lazy stack.

    When ``is_leaf`` is one of the ``_NESTED_TENSORS_AS_LISTS`` /
    ``_NESTED_TENSORS_AS_LISTS_NONTENSOR`` sentinels, the values of every
    tensordict in ``self.tensordicts`` are yielded one tensordict after the
    other. For any other ``is_leaf`` the parent class implementation is used.

    Args:
        include_nested: forwarded unchanged to the underlying ``values`` call.
        leaves_only: forwarded unchanged to the underlying ``values`` call.
        is_leaf: leaf predicate, or one of the list-mode sentinels above.

    Keyword Args:
        sort: forwarded unchanged to the underlying ``values`` call; in
            list mode this means it is passed to each stacked tensordict
            separately.
    """
    forwarded = {
        "include_nested": include_nested,
        "leaves_only": leaves_only,
        "is_leaf": is_leaf,
        "sort": sort,
    }
    list_mode = is_leaf in (
        _NESTED_TENSORS_AS_LISTS,
        _NESTED_TENSORS_AS_LISTS_NONTENSOR,
    )
    if list_mode:
        # Walk the stacked tensordicts in order, exhausting each in turn.
        for stacked_td in self.tensordicts:
            yield from stacked_td.values(**forwarded)
        return
    yield from super().values(**forwarded)
15341548

1535-
def items(self, include_nested=False, leaves_only=False, is_leaf=None):
1549+
def items(
1550+
self,
1551+
include_nested=False,
1552+
leaves_only=False,
1553+
is_leaf=None,
1554+
*,
1555+
sort: bool = False,
1556+
):
15361557
if is_leaf not in (
15371558
_NESTED_TENSORS_AS_LISTS,
15381559
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
15391560
):
15401561
yield from super().items(
1541-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
1562+
include_nested=include_nested,
1563+
leaves_only=leaves_only,
1564+
is_leaf=is_leaf,
1565+
sort=sort,
15421566
)
15431567
else:
15441568
for i, td in enumerate(self.tensordicts):
15451569
for key, val in td.items(
15461570
include_nested=include_nested,
15471571
leaves_only=leaves_only,
15481572
is_leaf=is_leaf,
1573+
sort=sort,
15491574
):
15501575
if isinstance(key, str):
15511576
key = (str(i), key)

tensordict/_td.py

Lines changed: 80 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -3074,12 +3074,23 @@ def keys(
30743074
include_nested: bool = False,
30753075
leaves_only: bool = False,
30763076
is_leaf: Callable[[Type], bool] | None = None,
3077+
*,
3078+
sort: bool = False,
30773079
) -> _TensorDictKeysView:
30783080
if not include_nested and not leaves_only and is_leaf is None:
3079-
return _StringKeys(self._tensordict.keys())
3081+
if not sort:
3082+
return _StringKeys(self._tensordict.keys())
3083+
else:
3084+
return sorted(
3085+
_StringKeys(self._tensordict.keys()),
3086+
key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
3087+
)
30803088
else:
30813089
return self._nested_keys(
3082-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
3090+
include_nested=include_nested,
3091+
leaves_only=leaves_only,
3092+
is_leaf=is_leaf,
3093+
sort=sort,
30833094
)
30843095

30853096
@cache # noqa: B019
@@ -3088,12 +3099,15 @@ def _nested_keys(
30883099
include_nested: bool = False,
30893100
leaves_only: bool = False,
30903101
is_leaf: Callable[[Type], bool] | None = None,
3102+
*,
3103+
sort: bool = False,
30913104
) -> _TensorDictKeysView:
30923105
return _TensorDictKeysView(
30933106
self,
30943107
include_nested=include_nested,
30953108
leaves_only=leaves_only,
30963109
is_leaf=is_leaf,
3110+
sort=sort,
30973111
)
30983112

30993113
# some custom methods for efficiency
@@ -3102,81 +3116,68 @@ def items(
31023116
include_nested: bool = False,
31033117
leaves_only: bool = False,
31043118
is_leaf: Callable[[Type], bool] | None = None,
3119+
*,
3120+
sort: bool = False,
31053121
) -> Iterator[tuple[str, CompatibleType]]:
31063122
if not include_nested and not leaves_only:
3107-
return self._tensordict.items()
3108-
elif include_nested and leaves_only:
3123+
if not sort:
3124+
return self._tensordict.items()
3125+
return sorted(self._tensordict.items(), key=lambda x: x[0])
3126+
elif include_nested and leaves_only and not sort:
31093127
is_leaf = _default_is_leaf if is_leaf is None else is_leaf
31103128
result = []
3111-
if is_dynamo_compiling():
3112-
3113-
def fast_iter():
3114-
for key, val in self._tensordict.items():
3115-
if not is_leaf(type(val)):
3116-
for _key, _val in val.items(
3117-
include_nested=include_nested,
3118-
leaves_only=leaves_only,
3119-
is_leaf=is_leaf,
3120-
):
3121-
result.append(
3122-
(
3123-
(
3124-
key,
3125-
*(
3126-
(_key,)
3127-
if isinstance(_key, str)
3128-
else _key
3129-
),
3130-
),
3131-
_val,
3132-
)
3133-
)
3134-
else:
3135-
result.append((key, val))
3136-
return result
31373129

3138-
else:
3139-
# dynamo doesn't like generators
3140-
def fast_iter():
3141-
for key, val in self._tensordict.items():
3142-
if not is_leaf(type(val)):
3143-
yield from (
3144-
(
3145-
(
3146-
key,
3147-
*((_key,) if isinstance(_key, str) else _key),
3148-
),
3149-
_val,
3150-
)
3151-
for _key, _val in val.items(
3152-
include_nested=include_nested,
3153-
leaves_only=leaves_only,
3154-
is_leaf=is_leaf,
3155-
)
3156-
)
3157-
else:
3158-
yield (key, val)
3130+
def fast_iter():
3131+
for key, val in self._tensordict.items():
3132+
# We could easily make this faster, here we're iterating twice over the keys,
3133+
# but we could iterate just once.
3134+
# Ideally we should make a "dirty" list of items then call unravel_key on all of them.
3135+
if not is_leaf(type(val)):
3136+
for _key, _val in val.items(
3137+
include_nested=include_nested,
3138+
leaves_only=leaves_only,
3139+
is_leaf=is_leaf,
3140+
):
3141+
if isinstance(_key, str):
3142+
_key = (key, _key)
3143+
else:
3144+
_key = (key, *_key)
3145+
result.append((_key, _val))
3146+
else:
3147+
result.append((key, val))
3148+
return result
31593149

31603150
return fast_iter()
31613151
else:
31623152
return super().items(
3163-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
3153+
include_nested=include_nested,
3154+
leaves_only=leaves_only,
3155+
is_leaf=is_leaf,
3156+
sort=sort,
31643157
)
31653158

31663159
def values(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
    is_leaf: Callable[[Type], bool] | None = None,
    *,
    sort: bool = False,
) -> Iterator[CompatibleType]:
    """Return the values stored in this tensordict.

    Fast path: when neither ``include_nested`` nor ``leaves_only`` is set,
    the values are read straight from ``self._tensordict`` (``is_leaf`` is
    not consulted on this path). Otherwise the call is delegated to
    ``TensorDictBase.values`` with every argument forwarded.

    Args:
        include_nested: if set, delegate to ``TensorDictBase.values``.
        leaves_only: if set, delegate to ``TensorDictBase.values``.
        is_leaf: leaf predicate, only used by the delegated call.

    Keyword Args:
        sort: on the fast path, return the values as a tuple ordered by
            their key instead of the dict view; forwarded as-is otherwise.

    Returns:
        The dict values view (fast path, unsorted), a tuple of values ordered
        by key (fast path, sorted; empty tuple for an empty tensordict), or
        whatever ``TensorDictBase.values`` returns.
    """
    if include_nested or leaves_only:
        return TensorDictBase.values(
            self,
            include_nested=include_nested,
            leaves_only=leaves_only,
            is_leaf=is_leaf,
            sort=sort,
        )
    if not sort:
        return self._tensordict.values()
    # Project the sorted (key, value) pairs directly onto their values.
    # Transposing with ``zip(*pairs)`` and indexing ``[1]`` raises
    # ``IndexError`` when there are no entries at all.
    ordered_items = sorted(self._tensordict.items(), key=lambda item: item[0])
    return tuple(value for _, value in ordered_items)
31813182

31823183

@@ -3465,9 +3466,14 @@ def keys(
34653466
include_nested: bool = False,
34663467
leaves_only: bool = False,
34673468
is_leaf: Callable[[Type], bool] | None = None,
3469+
*,
3470+
sort: bool = False,
34683471
) -> _TensorDictKeysView:
34693472
return self._source.keys(
3470-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
3473+
include_nested=include_nested,
3474+
leaves_only=leaves_only,
3475+
is_leaf=is_leaf,
3476+
sort=sort,
34713477
)
34723478

34733479
def entry_class(self, key: NestedKey) -> type:
@@ -4099,30 +4105,37 @@ def __init__(
40994105
include_nested: bool,
41004106
leaves_only: bool,
41014107
is_leaf: Callable[[Type], bool] = None,
4108+
sort: bool = False,
41024109
) -> None:
41034110
self.tensordict = tensordict
41044111
self.include_nested = include_nested
41054112
self.leaves_only = leaves_only
41064113
if is_leaf is None:
41074114
is_leaf = _default_is_leaf
41084115
self.is_leaf = is_leaf
4116+
self.sort = sort
41094117

41104118
def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
    """Iterate over the keys selected by this view.

    Without ``include_nested`` the keys come from ``self._keys()``; with
    ``leaves_only`` the entries whose class is a tensor collection are
    dropped. With ``include_nested`` the keys come from
    ``self._iter_helper`` and single-element keys are unwrapped to their
    only element. When ``self.sort`` is set, all keys are collected and
    yielded in sorted order, tuple keys being compared through their
    ``"."``-joined form.
    """

    def _unsorted_keys():
        if self.include_nested:
            for nested_key in self._iter_helper(self.tensordict):
                # One-element keys are reported as the bare element.
                yield nested_key if len(nested_key) > 1 else nested_key[0]
            return
        if not self.leaves_only:
            yield from self._keys()
            return
        for key in self._keys():
            entry_cls = self.tensordict.entry_class(key)
            if not _is_tensor_collection(entry_cls):
                yield key

    def _sort_key(key):
        # Tuples and strings are made comparable by flattening tuples.
        return ".".join(key) if isinstance(key, tuple) else key

    selected = _unsorted_keys()
    if self.sort:
        selected = sorted(selected, key=_sort_key)
    yield from selected
41264139
def _iter_helper(
41274140
self, tensordict: T, prefix: str | None = None
41284141
) -> Iterable[str] | Iterable[tuple[str, ...]]:

0 commit comments

Comments
 (0)