@@ -3144,12 +3144,23 @@ def keys(
3144
3144
include_nested : bool = False ,
3145
3145
leaves_only : bool = False ,
3146
3146
is_leaf : Callable [[Type ], bool ] | None = None ,
3147
+ * ,
3148
+ sort : bool = False ,
3147
3149
) -> _TensorDictKeysView :
3148
3150
if not include_nested and not leaves_only and is_leaf is None :
3149
- return _StringKeys (self ._tensordict .keys ())
3151
+ if not sort :
3152
+ return _StringKeys (self ._tensordict .keys ())
3153
+ else :
3154
+ return sorted (
3155
+ _StringKeys (self ._tensordict .keys ()),
3156
+ key = lambda x : "." .join (x ) if isinstance (x , tuple ) else x ,
3157
+ )
3150
3158
else :
3151
3159
return self ._nested_keys (
3152
- include_nested = include_nested , leaves_only = leaves_only , is_leaf = is_leaf
3160
+ include_nested = include_nested ,
3161
+ leaves_only = leaves_only ,
3162
+ is_leaf = is_leaf ,
3163
+ sort = sort ,
3153
3164
)
3154
3165
3155
3166
@cache # noqa: B019
@@ -3158,12 +3169,15 @@ def _nested_keys(
3158
3169
include_nested : bool = False ,
3159
3170
leaves_only : bool = False ,
3160
3171
is_leaf : Callable [[Type ], bool ] | None = None ,
3172
+ * ,
3173
+ sort : bool = False ,
3161
3174
) -> _TensorDictKeysView :
3162
3175
return _TensorDictKeysView (
3163
3176
self ,
3164
3177
include_nested = include_nested ,
3165
3178
leaves_only = leaves_only ,
3166
3179
is_leaf = is_leaf ,
3180
+ sort = sort ,
3167
3181
)
3168
3182
3169
3183
# some custom methods for efficiency
@@ -3172,81 +3186,68 @@ def items(
3172
3186
include_nested : bool = False ,
3173
3187
leaves_only : bool = False ,
3174
3188
is_leaf : Callable [[Type ], bool ] | None = None ,
3189
+ * ,
3190
+ sort : bool = False ,
3175
3191
) -> Iterator [tuple [str , CompatibleType ]]:
3176
3192
if not include_nested and not leaves_only :
3177
- return self ._tensordict .items ()
3178
- elif include_nested and leaves_only :
3193
+ if not sort :
3194
+ return self ._tensordict .items ()
3195
+ return sorted (self ._tensordict .items (), key = lambda x : x [0 ])
3196
+ elif include_nested and leaves_only and not sort :
3179
3197
is_leaf = _default_is_leaf if is_leaf is None else is_leaf
3180
3198
result = []
3181
- if is_dynamo_compiling ():
3182
-
3183
- def fast_iter ():
3184
- for key , val in self ._tensordict .items ():
3185
- if not is_leaf (type (val )):
3186
- for _key , _val in val .items (
3187
- include_nested = include_nested ,
3188
- leaves_only = leaves_only ,
3189
- is_leaf = is_leaf ,
3190
- ):
3191
- result .append (
3192
- (
3193
- (
3194
- key ,
3195
- * (
3196
- (_key ,)
3197
- if isinstance (_key , str )
3198
- else _key
3199
- ),
3200
- ),
3201
- _val ,
3202
- )
3203
- )
3204
- else :
3205
- result .append ((key , val ))
3206
- return result
3207
3199
3208
- else :
3209
- # dynamo doesn't like generators
3210
- def fast_iter ():
3211
- for key , val in self ._tensordict .items ():
3212
- if not is_leaf (type (val )):
3213
- yield from (
3214
- (
3215
- (
3216
- key ,
3217
- * ((_key ,) if isinstance (_key , str ) else _key ),
3218
- ),
3219
- _val ,
3220
- )
3221
- for _key , _val in val .items (
3222
- include_nested = include_nested ,
3223
- leaves_only = leaves_only ,
3224
- is_leaf = is_leaf ,
3225
- )
3226
- )
3227
- else :
3228
- yield (key , val )
3200
+ def fast_iter ():
3201
+ for key , val in self ._tensordict .items ():
3202
+ # We could easily make this faster, here we're iterating twice over the keys,
3203
+ # but we could iterate just once.
3204
+ # Ideally we should make a "dirty" list of items then call unravel_key on all of them.
3205
+ if not is_leaf (type (val )):
3206
+ for _key , _val in val .items (
3207
+ include_nested = include_nested ,
3208
+ leaves_only = leaves_only ,
3209
+ is_leaf = is_leaf ,
3210
+ ):
3211
+ if isinstance (_key , str ):
3212
+ _key = (key , _key )
3213
+ else :
3214
+ _key = (key , * _key )
3215
+ result .append ((_key , _val ))
3216
+ else :
3217
+ result .append ((key , val ))
3218
+ return result
3229
3219
3230
3220
return fast_iter ()
3231
3221
else :
3232
3222
return super ().items (
3233
- include_nested = include_nested , leaves_only = leaves_only , is_leaf = is_leaf
3223
+ include_nested = include_nested ,
3224
+ leaves_only = leaves_only ,
3225
+ is_leaf = is_leaf ,
3226
+ sort = sort ,
3234
3227
)
3235
3228
3236
3229
def values (
3237
3230
self ,
3238
3231
include_nested : bool = False ,
3239
3232
leaves_only : bool = False ,
3240
3233
is_leaf : Callable [[Type ], bool ] | None = None ,
3234
+ * ,
3235
+ sort : bool = False ,
3241
3236
) -> Iterator [tuple [str , CompatibleType ]]:
3242
3237
if not include_nested and not leaves_only :
3243
- return self ._tensordict .values ()
3238
+ if not sort :
3239
+ return self ._tensordict .values ()
3240
+ else :
3241
+ return list (zip (* sorted (self ._tensordict .items (), key = lambda x : x [0 ])))[
3242
+ 1
3243
+ ]
3244
3244
else :
3245
3245
return TensorDictBase .values (
3246
3246
self ,
3247
3247
include_nested = include_nested ,
3248
3248
leaves_only = leaves_only ,
3249
3249
is_leaf = is_leaf ,
3250
+ sort = sort ,
3250
3251
)
3251
3252
3252
3253
@@ -3535,9 +3536,14 @@ def keys(
3535
3536
include_nested : bool = False ,
3536
3537
leaves_only : bool = False ,
3537
3538
is_leaf : Callable [[Type ], bool ] | None = None ,
3539
+ * ,
3540
+ sort : bool = False ,
3538
3541
) -> _TensorDictKeysView :
3539
3542
return self ._source .keys (
3540
- include_nested = include_nested , leaves_only = leaves_only , is_leaf = is_leaf
3543
+ include_nested = include_nested ,
3544
+ leaves_only = leaves_only ,
3545
+ is_leaf = is_leaf ,
3546
+ sort = sort ,
3541
3547
)
3542
3548
3543
3549
def entry_class (self , key : NestedKey ) -> type :
@@ -4172,29 +4178,40 @@ def __init__(
4172
4178
include_nested : bool ,
4173
4179
leaves_only : bool ,
4174
4180
is_leaf : Callable [[Type ], bool ] = None ,
4181
+ sort : bool = False ,
4175
4182
) -> None :
4176
4183
self .tensordict = tensordict
4177
4184
self .include_nested = include_nested
4178
4185
self .leaves_only = leaves_only
4179
4186
if is_leaf is None :
4180
4187
is_leaf = _default_is_leaf
4181
4188
self .is_leaf = is_leaf
4189
+ self .sort = sort
4182
4190
4183
4191
def __iter__ (self ) -> Iterable [str ] | Iterable [tuple [str , ...]]:
4184
- if not self .include_nested :
4185
- if self .leaves_only :
4186
- for key in self ._keys ():
4187
- target_class = self .tensordict .entry_class (key )
4188
- if _is_tensor_collection (target_class ):
4189
- continue
4190
- yield key
4192
+ def _iter ():
4193
+ if not self .include_nested :
4194
+ if self .leaves_only :
4195
+ for key in self ._keys ():
4196
+ target_class = self .tensordict .entry_class (key )
4197
+ if _is_tensor_collection (target_class ):
4198
+ continue
4199
+ yield key
4200
+ else :
4201
+ yield from self ._keys ()
4191
4202
else :
4192
- yield from self ._keys ()
4193
- else :
4194
- yield from (
4195
- key if len (key ) > 1 else key [0 ]
4196
- for key in self ._iter_helper (self .tensordict )
4203
+ yield from (
4204
+ key if len (key ) > 1 else key [0 ]
4205
+ for key in self ._iter_helper (self .tensordict )
4206
+ )
4207
+
4208
+ if self .sort :
4209
+ yield from sorted (
4210
+ _iter (),
4211
+ key = lambda key : "." .join (key ) if isinstance (key , tuple ) else key ,
4197
4212
)
4213
+ else :
4214
+ yield from _iter ()
4198
4215
4199
4216
def _iter_helper (
4200
4217
self , tensordict : T , prefix : str | None = None
0 commit comments