@@ -3074,12 +3074,23 @@ def keys(
3074
3074
include_nested : bool = False ,
3075
3075
leaves_only : bool = False ,
3076
3076
is_leaf : Callable [[Type ], bool ] | None = None ,
3077
+ * ,
3078
+ sort : bool = False ,
3077
3079
) -> _TensorDictKeysView :
3078
3080
if not include_nested and not leaves_only and is_leaf is None :
3079
- return _StringKeys (self ._tensordict .keys ())
3081
+ if not sort :
3082
+ return _StringKeys (self ._tensordict .keys ())
3083
+ else :
3084
+ return sorted (
3085
+ _StringKeys (self ._tensordict .keys ()),
3086
+ key = lambda x : "." .join (x ) if isinstance (x , tuple ) else x ,
3087
+ )
3080
3088
else :
3081
3089
return self ._nested_keys (
3082
- include_nested = include_nested , leaves_only = leaves_only , is_leaf = is_leaf
3090
+ include_nested = include_nested ,
3091
+ leaves_only = leaves_only ,
3092
+ is_leaf = is_leaf ,
3093
+ sort = sort ,
3083
3094
)
3084
3095
3085
3096
@cache # noqa: B019
@@ -3088,12 +3099,15 @@ def _nested_keys(
3088
3099
include_nested : bool = False ,
3089
3100
leaves_only : bool = False ,
3090
3101
is_leaf : Callable [[Type ], bool ] | None = None ,
3102
+ * ,
3103
+ sort : bool = False ,
3091
3104
) -> _TensorDictKeysView :
3092
3105
return _TensorDictKeysView (
3093
3106
self ,
3094
3107
include_nested = include_nested ,
3095
3108
leaves_only = leaves_only ,
3096
3109
is_leaf = is_leaf ,
3110
+ sort = sort ,
3097
3111
)
3098
3112
3099
3113
# some custom methods for efficiency
@@ -3102,81 +3116,68 @@ def items(
3102
3116
include_nested : bool = False ,
3103
3117
leaves_only : bool = False ,
3104
3118
is_leaf : Callable [[Type ], bool ] | None = None ,
3119
+ * ,
3120
+ sort : bool = False ,
3105
3121
) -> Iterator [tuple [str , CompatibleType ]]:
3106
3122
if not include_nested and not leaves_only :
3107
- return self ._tensordict .items ()
3108
- elif include_nested and leaves_only :
3123
+ if not sort :
3124
+ return self ._tensordict .items ()
3125
+ return sorted (self ._tensordict .items (), key = lambda x : x [0 ])
3126
+ elif include_nested and leaves_only and not sort :
3109
3127
is_leaf = _default_is_leaf if is_leaf is None else is_leaf
3110
3128
result = []
3111
- if is_dynamo_compiling ():
3112
-
3113
- def fast_iter ():
3114
- for key , val in self ._tensordict .items ():
3115
- if not is_leaf (type (val )):
3116
- for _key , _val in val .items (
3117
- include_nested = include_nested ,
3118
- leaves_only = leaves_only ,
3119
- is_leaf = is_leaf ,
3120
- ):
3121
- result .append (
3122
- (
3123
- (
3124
- key ,
3125
- * (
3126
- (_key ,)
3127
- if isinstance (_key , str )
3128
- else _key
3129
- ),
3130
- ),
3131
- _val ,
3132
- )
3133
- )
3134
- else :
3135
- result .append ((key , val ))
3136
- return result
3137
3129
3138
- else :
3139
- # dynamo doesn't like generators
3140
- def fast_iter ():
3141
- for key , val in self ._tensordict .items ():
3142
- if not is_leaf (type (val )):
3143
- yield from (
3144
- (
3145
- (
3146
- key ,
3147
- * ((_key ,) if isinstance (_key , str ) else _key ),
3148
- ),
3149
- _val ,
3150
- )
3151
- for _key , _val in val .items (
3152
- include_nested = include_nested ,
3153
- leaves_only = leaves_only ,
3154
- is_leaf = is_leaf ,
3155
- )
3156
- )
3157
- else :
3158
- yield (key , val )
3130
+ def fast_iter ():
3131
+ for key , val in self ._tensordict .items ():
3132
+ # We could easily make this faster, here we're iterating twice over the keys,
3133
+ # but we could iterate just once.
3134
+ # Ideally we should make a "dirty" list of items then call unravel_key on all of them.
3135
+ if not is_leaf (type (val )):
3136
+ for _key , _val in val .items (
3137
+ include_nested = include_nested ,
3138
+ leaves_only = leaves_only ,
3139
+ is_leaf = is_leaf ,
3140
+ ):
3141
+ if isinstance (_key , str ):
3142
+ _key = (key , _key )
3143
+ else :
3144
+ _key = (key , * _key )
3145
+ result .append ((_key , _val ))
3146
+ else :
3147
+ result .append ((key , val ))
3148
+ return result
3159
3149
3160
3150
return fast_iter ()
3161
3151
else :
3162
3152
return super ().items (
3163
- include_nested = include_nested , leaves_only = leaves_only , is_leaf = is_leaf
3153
+ include_nested = include_nested ,
3154
+ leaves_only = leaves_only ,
3155
+ is_leaf = is_leaf ,
3156
+ sort = sort ,
3164
3157
)
3165
3158
3166
3159
def values (
3167
3160
self ,
3168
3161
include_nested : bool = False ,
3169
3162
leaves_only : bool = False ,
3170
3163
is_leaf : Callable [[Type ], bool ] | None = None ,
3164
+ * ,
3165
+ sort : bool = False ,
3171
3166
) -> Iterator [tuple [str , CompatibleType ]]:
3172
3167
if not include_nested and not leaves_only :
3173
- return self ._tensordict .values ()
3168
+ if not sort :
3169
+ return self ._tensordict .values ()
3170
+ else :
3171
+ return list (zip (* sorted (self ._tensordict .items (), key = lambda x : x [0 ])))[
3172
+ 1
3173
+ ]
3174
3174
else :
3175
3175
return TensorDictBase .values (
3176
3176
self ,
3177
3177
include_nested = include_nested ,
3178
3178
leaves_only = leaves_only ,
3179
3179
is_leaf = is_leaf ,
3180
+ sort = sort ,
3180
3181
)
3181
3182
3182
3183
@@ -3465,9 +3466,14 @@ def keys(
3465
3466
include_nested : bool = False ,
3466
3467
leaves_only : bool = False ,
3467
3468
is_leaf : Callable [[Type ], bool ] | None = None ,
3469
+ * ,
3470
+ sort : bool = False ,
3468
3471
) -> _TensorDictKeysView :
3469
3472
return self ._source .keys (
3470
- include_nested = include_nested , leaves_only = leaves_only , is_leaf = is_leaf
3473
+ include_nested = include_nested ,
3474
+ leaves_only = leaves_only ,
3475
+ is_leaf = is_leaf ,
3476
+ sort = sort ,
3471
3477
)
3472
3478
3473
3479
def entry_class (self , key : NestedKey ) -> type :
@@ -4099,30 +4105,37 @@ def __init__(
4099
4105
include_nested : bool ,
4100
4106
leaves_only : bool ,
4101
4107
is_leaf : Callable [[Type ], bool ] = None ,
4108
+ sort : bool = False ,
4102
4109
) -> None :
4103
4110
self .tensordict = tensordict
4104
4111
self .include_nested = include_nested
4105
4112
self .leaves_only = leaves_only
4106
4113
if is_leaf is None :
4107
4114
is_leaf = _default_is_leaf
4108
4115
self .is_leaf = is_leaf
4116
+ self .sort = sort
4109
4117
4110
4118
def __iter__ (self ) -> Iterable [str ] | Iterable [tuple [str , ...]]:
4111
- if not self .include_nested :
4112
- if self .leaves_only :
4113
- for key in self ._keys ():
4114
- target_class = self .tensordict .entry_class (key )
4115
- if _is_tensor_collection (target_class ):
4116
- continue
4117
- yield key
4119
+ def _iter ():
4120
+ if not self .include_nested :
4121
+ if self .leaves_only :
4122
+ for key in self ._keys ():
4123
+ target_class = self .tensordict .entry_class (key )
4124
+ if _is_tensor_collection (target_class ):
4125
+ continue
4126
+ yield key
4127
+ else :
4128
+ yield from self ._keys ()
4118
4129
else :
4119
- yield from self ._keys ()
4120
- else :
4121
- yield from (
4122
- key if len (key ) > 1 else key [0 ]
4123
- for key in self ._iter_helper (self .tensordict )
4124
- )
4130
+ yield from (
4131
+ key if len (key ) > 1 else key [0 ]
4132
+ for key in self ._iter_helper (self .tensordict )
4133
+ )
4125
4134
4135
+ if self .sort :
4136
+ yield from sorted (_iter (), key = lambda key : "." .join (key ) if isinstance (key , tuple ) else key )
4137
+ else :
4138
+ yield from _iter ()
4126
4139
def _iter_helper (
4127
4140
self , tensordict : T , prefix : str | None = None
4128
4141
) -> Iterable [str ] | Iterable [tuple [str , ...]]:
0 commit comments