Skip to content

Commit e5c54aa

Browse files
author
Vincent Moens
committed
[Feature] sorted keys, values and items
ghstack-source-id: 79d007d
Pull Request resolved: #965
1 parent fc323e5 commit e5c54aa

File tree

6 files changed

+271
-137
lines changed

6 files changed

+271
-137
lines changed

tensordict/_lazy.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1550,45 +1550,70 @@ def keys(
15501550
include_nested: bool = False,
15511551
leaves_only: bool = False,
15521552
is_leaf: Callable[[Type], bool] | None = None,
1553+
*,
1554+
sort: bool = False,
15531555
) -> _LazyStackedTensorDictKeysView:
15541556
keys = _LazyStackedTensorDictKeysView(
15551557
self,
15561558
include_nested=include_nested,
15571559
leaves_only=leaves_only,
15581560
is_leaf=is_leaf,
1561+
sort=sort,
15591562
)
15601563
return keys
15611564

1562-
def values(
    self,
    include_nested=False,
    leaves_only=False,
    is_leaf=None,
    *,
    sort: bool = False,
):
    """Yield the values held by this lazy stack.

    When ``is_leaf`` is one of the ``_NESTED_TENSORS_AS_LISTS`` /
    ``_NESTED_TENSORS_AS_LISTS_NONTENSOR`` sentinels, the values of every
    element of ``self.tensordicts`` are yielded one tensordict after the
    other. For any other ``is_leaf`` the parent class implementation is used.

    Args:
        include_nested: forwarded unchanged to the underlying ``values`` call.
        leaves_only: forwarded unchanged to the underlying ``values`` call.
        is_leaf: selects the iteration mode (see above) and is forwarded.
        sort (keyword-only): forwarded unchanged to the underlying ``values``
            call.
    """
    # Every delegated call receives exactly the same keyword arguments.
    forwarded = {
        "include_nested": include_nested,
        "leaves_only": leaves_only,
        "is_leaf": is_leaf,
        "sort": sort,
    }
    list_sentinels = (
        _NESTED_TENSORS_AS_LISTS,
        _NESTED_TENSORS_AS_LISTS_NONTENSOR,
    )
    if is_leaf in list_sentinels:
        # "Nested tensors as lists" mode: walk the stacked tensordicts in order.
        for element in self.tensordicts:
            yield from element.values(**forwarded)
        return
    yield from super().values(**forwarded)
15771591

1578-
def items(self, include_nested=False, leaves_only=False, is_leaf=None):
1592+
def items(
1593+
self,
1594+
include_nested=False,
1595+
leaves_only=False,
1596+
is_leaf=None,
1597+
*,
1598+
sort: bool = False,
1599+
):
15791600
if is_leaf not in (
15801601
_NESTED_TENSORS_AS_LISTS,
15811602
_NESTED_TENSORS_AS_LISTS_NONTENSOR,
15821603
):
15831604
yield from super().items(
1584-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
1605+
include_nested=include_nested,
1606+
leaves_only=leaves_only,
1607+
is_leaf=is_leaf,
1608+
sort=sort,
15851609
)
15861610
else:
15871611
for i, td in enumerate(self.tensordicts):
15881612
for key, val in td.items(
15891613
include_nested=include_nested,
15901614
leaves_only=leaves_only,
15911615
is_leaf=is_leaf,
1616+
sort=sort,
15921617
):
15931618
if isinstance(key, str):
15941619
key = (str(i), key)
@@ -3381,9 +3406,14 @@ def keys(
33813406
include_nested: bool = False,
33823407
leaves_only: bool = False,
33833408
is_leaf: Callable[[Type], bool] | None = None,
3409+
*,
3410+
sort: bool = False,
33843411
) -> _TensorDictKeysView:
33853412
return self._source.keys(
3386-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
3413+
include_nested=include_nested,
3414+
leaves_only=leaves_only,
3415+
is_leaf=is_leaf,
3416+
sort=sort,
33873417
)
33883418

33893419
def _select(

tensordict/_td.py

Lines changed: 83 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -3144,12 +3144,23 @@ def keys(
31443144
include_nested: bool = False,
31453145
leaves_only: bool = False,
31463146
is_leaf: Callable[[Type], bool] | None = None,
3147+
*,
3148+
sort: bool = False,
31473149
) -> _TensorDictKeysView:
31483150
if not include_nested and not leaves_only and is_leaf is None:
3149-
return _StringKeys(self._tensordict.keys())
3151+
if not sort:
3152+
return _StringKeys(self._tensordict.keys())
3153+
else:
3154+
return sorted(
3155+
_StringKeys(self._tensordict.keys()),
3156+
key=lambda x: ".".join(x) if isinstance(x, tuple) else x,
3157+
)
31503158
else:
31513159
return self._nested_keys(
3152-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
3160+
include_nested=include_nested,
3161+
leaves_only=leaves_only,
3162+
is_leaf=is_leaf,
3163+
sort=sort,
31533164
)
31543165

31553166
@cache # noqa: B019
@@ -3158,12 +3169,15 @@ def _nested_keys(
31583169
include_nested: bool = False,
31593170
leaves_only: bool = False,
31603171
is_leaf: Callable[[Type], bool] | None = None,
3172+
*,
3173+
sort: bool = False,
31613174
) -> _TensorDictKeysView:
31623175
return _TensorDictKeysView(
31633176
self,
31643177
include_nested=include_nested,
31653178
leaves_only=leaves_only,
31663179
is_leaf=is_leaf,
3180+
sort=sort,
31673181
)
31683182

31693183
# some custom methods for efficiency
@@ -3172,81 +3186,68 @@ def items(
31723186
include_nested: bool = False,
31733187
leaves_only: bool = False,
31743188
is_leaf: Callable[[Type], bool] | None = None,
3189+
*,
3190+
sort: bool = False,
31753191
) -> Iterator[tuple[str, CompatibleType]]:
31763192
if not include_nested and not leaves_only:
3177-
return self._tensordict.items()
3178-
elif include_nested and leaves_only:
3193+
if not sort:
3194+
return self._tensordict.items()
3195+
return sorted(self._tensordict.items(), key=lambda x: x[0])
3196+
elif include_nested and leaves_only and not sort:
31793197
is_leaf = _default_is_leaf if is_leaf is None else is_leaf
31803198
result = []
3181-
if is_dynamo_compiling():
3182-
3183-
def fast_iter():
3184-
for key, val in self._tensordict.items():
3185-
if not is_leaf(type(val)):
3186-
for _key, _val in val.items(
3187-
include_nested=include_nested,
3188-
leaves_only=leaves_only,
3189-
is_leaf=is_leaf,
3190-
):
3191-
result.append(
3192-
(
3193-
(
3194-
key,
3195-
*(
3196-
(_key,)
3197-
if isinstance(_key, str)
3198-
else _key
3199-
),
3200-
),
3201-
_val,
3202-
)
3203-
)
3204-
else:
3205-
result.append((key, val))
3206-
return result
32073199

3208-
else:
3209-
# dynamo doesn't like generators
3210-
def fast_iter():
3211-
for key, val in self._tensordict.items():
3212-
if not is_leaf(type(val)):
3213-
yield from (
3214-
(
3215-
(
3216-
key,
3217-
*((_key,) if isinstance(_key, str) else _key),
3218-
),
3219-
_val,
3220-
)
3221-
for _key, _val in val.items(
3222-
include_nested=include_nested,
3223-
leaves_only=leaves_only,
3224-
is_leaf=is_leaf,
3225-
)
3226-
)
3227-
else:
3228-
yield (key, val)
3200+
def fast_iter():
3201+
for key, val in self._tensordict.items():
3202+
# We could easily make this faster, here we're iterating twice over the keys,
3203+
# but we could iterate just once.
3204+
# Ideally we should make a "dirty" list of items then call unravel_key on all of them.
3205+
if not is_leaf(type(val)):
3206+
for _key, _val in val.items(
3207+
include_nested=include_nested,
3208+
leaves_only=leaves_only,
3209+
is_leaf=is_leaf,
3210+
):
3211+
if isinstance(_key, str):
3212+
_key = (key, _key)
3213+
else:
3214+
_key = (key, *_key)
3215+
result.append((_key, _val))
3216+
else:
3217+
result.append((key, val))
3218+
return result
32293219

32303220
return fast_iter()
32313221
else:
32323222
return super().items(
3233-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
3223+
include_nested=include_nested,
3224+
leaves_only=leaves_only,
3225+
is_leaf=is_leaf,
3226+
sort=sort,
32343227
)
32353228

32363229
def values(
    self,
    include_nested: bool = False,
    leaves_only: bool = False,
    is_leaf: Callable[[Type], bool] | None = None,
    *,
    sort: bool = False,
) -> Iterator[CompatibleType]:
    """Return the values stored in this tensordict.

    Args:
        include_nested: if ``True``, the call is delegated to
            ``TensorDictBase.values``.
        leaves_only: if ``True``, the call is delegated to
            ``TensorDictBase.values``.
        is_leaf: forwarded to ``TensorDictBase.values`` on the delegated
            path; it is not consulted on the fast path below.
        sort (keyword-only): if ``True``, values are returned in the order
            of their sorted keys.

    Returns:
        On the fast path (``include_nested`` and ``leaves_only`` both
        ``False``): the values view of the underlying dict when ``sort`` is
        ``False``, or a tuple of values ordered by key when ``sort`` is
        ``True`` (an empty tuple for an empty tensordict). Otherwise,
        whatever ``TensorDictBase.values`` returns.
    """
    if include_nested or leaves_only:
        return TensorDictBase.values(
            self,
            include_nested=include_nested,
            leaves_only=leaves_only,
            is_leaf=is_leaf,
            sort=sort,
        )
    if not sort:
        return self._tensordict.values()
    # Build the tuple directly instead of ``list(zip(*sorted(...)))[1]``:
    # the zip-transpose yields an empty list for an empty tensordict, so
    # indexing it with ``[1]`` raised IndexError.
    ordered_items = sorted(self._tensordict.items(), key=lambda item: item[0])
    return tuple(value for _, value in ordered_items)

32523253

@@ -3535,9 +3536,14 @@ def keys(
35353536
include_nested: bool = False,
35363537
leaves_only: bool = False,
35373538
is_leaf: Callable[[Type], bool] | None = None,
3539+
*,
3540+
sort: bool = False,
35383541
) -> _TensorDictKeysView:
35393542
return self._source.keys(
3540-
include_nested=include_nested, leaves_only=leaves_only, is_leaf=is_leaf
3543+
include_nested=include_nested,
3544+
leaves_only=leaves_only,
3545+
is_leaf=is_leaf,
3546+
sort=sort,
35413547
)
35423548

35433549
def entry_class(self, key: NestedKey) -> type:
@@ -4172,29 +4178,40 @@ def __init__(
41724178
include_nested: bool,
41734179
leaves_only: bool,
41744180
is_leaf: Callable[[Type], bool] = None,
4181+
sort: bool = False,
41754182
) -> None:
41764183
self.tensordict = tensordict
41774184
self.include_nested = include_nested
41784185
self.leaves_only = leaves_only
41794186
if is_leaf is None:
41804187
is_leaf = _default_is_leaf
41814188
self.is_leaf = is_leaf
4189+
self.sort = sort
41824190

41834191
def __iter__(self) -> Iterable[str] | Iterable[tuple[str, ...]]:
    """Iterate over the keys exposed by this view.

    * With ``include_nested`` set, keys come from ``self._iter_helper``;
      a key of length 1 is unwrapped to its single element.
    * Otherwise keys come from ``self._keys()``; when ``leaves_only`` is
      set, keys whose entry class is a tensor collection are dropped.

    When ``self.sort`` is true the keys are fully collected and yielded in
    sorted order, tuple keys being compared through their ``"."``-joined
    form.
    """

    def _unsorted():
        if self.include_nested:
            for nested_key in self._iter_helper(self.tensordict):
                # Length-1 keys are exposed as their single element.
                yield nested_key if len(nested_key) > 1 else nested_key[0]
        elif self.leaves_only:
            for name in self._keys():
                entry_type = self.tensordict.entry_class(name)
                if not _is_tensor_collection(entry_type):
                    yield name
        else:
            yield from self._keys()

    def _sort_key(entry):
        # Tuple keys are ordered by their dotted representation.
        return ".".join(entry) if isinstance(entry, tuple) else entry

    stream = _unsorted()
    if self.sort:
        stream = sorted(stream, key=_sort_key)
    yield from stream
41984215

41994216
def _iter_helper(
42004217
self, tensordict: T, prefix: str | None = None

0 commit comments

Comments
 (0)