@@ -604,6 +604,79 @@ def plot(
604
604
f"Unknown plotting backend { backend } with figure { figure } ."
605
605
)
606
606
607
def to_string(self, node_format_fn=lambda tree: tree.node_data.to_dict()) -> str:
    """Generates a string representation of the tree.

    This function can pull out information from each of the nodes in a tree,
    so it can be useful for debugging. The nodes are listed line-by-line in
    depth-first order, children being visited by increasing index. Each line
    contains the path to the node, followed by the string representation of
    that node generated with ``node_format_fn``. Each line is indented
    according to the number of steps in the path required to get to the
    corresponding node. Nodes whose ``rollout`` attribute is ``None`` are
    not listed.

    Args:
        node_format_fn (Callable, optional): User-defined function to
            generate a string for each node of the tree. The signature must
            be ``(Tree) -> Any``, and the output must be convertible to a
            string. If this argument is not given, the generated string is
            the node's :attr:`Tree.node_data` attribute converted to a dict.

    Returns:
        str: One line per listed node, joined with newlines. The string is
        empty when no node is listed.

    Examples:
        >>> from torchrl.data import MCTSForest
        >>> from tensordict import TensorDict
        >>> forest = MCTSForest()
        >>> td_root = TensorDict({"observation": 0,})
        >>> rollouts_data = [
        ...     # [(action, obs), ...]
        ...     [(3, 123), (1, 456)],
        ...     [(2, 359), (2, 3094)],
        ...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
        ...     [(1, 75)],
        ...     [(3, 123), (0, 948)],
        ...     [(2, 359), (2, 3094), (10, 68)],
        ...     [(2, 359), (2, 3094), (11, 9045)],
        ... ]
        >>> for rollout_data in rollouts_data:
        ...     td = td_root.clone().unsqueeze(0)
        ...     for action, obs in rollout_data:
        ...         td = td.update(TensorDict({
        ...             "action": [action],
        ...             "next": TensorDict({"observation": [obs]}, [1]),
        ...         }, [1]))
        ...         forest.extend(td)
        ...         td = td["next"].clone()
        ...
        >>> tree = forest.get_tree(td_root)
        >>> print(tree.to_string())
        (0,) {'observation': tensor(123)}
         (0, 0) {'observation': tensor(456)}
         (0, 1) {'observation': tensor(847)}
         (0, 2) {'observation': tensor(948)}
        (1,) {'observation': tensor(3094)}
         (1, 0) {'observation': tensor(68)}
         (1, 1) {'observation': tensor(9045)}
        (2,) {'observation': tensor(75)}
    """
    # Explicit stack for an iterative depth-first traversal. Each entry is
    # (node, path), where path is the tuple of child indices from the root.
    stack = [(self, ())]
    lines = []

    while stack:
        # Use a dedicated local instead of rebinding ``self``.
        node, path = stack.pop()

        if node.subtree is not None:
            # Push children in reverse so the child with index 0 is popped
            # (and therefore listed) first.
            for subtree_idx, subtree in reversed(list(enumerate(node.subtree))):
                stack.append((subtree, path + (subtree_idx,)))

        if node.rollout is not None:
            # Depth-1 nodes get no indentation; a non-positive repeat count
            # yields an empty string, so a depth-0 node is not indented either.
            indent = " " * (len(path) - 1)
            lines.append(f"{indent}{path} {node_format_fn(node)}")

    return "\n".join(lines)
679
+
607
680
608
681
class MCTSForest :
609
682
"""A collection of MCTS trees.
@@ -1164,6 +1237,63 @@ def valid_paths(cls, tree: Tree):
1164
1237
def __len__(self):
    """Return the number of entries currently held in ``self.data_map``."""
    entries = self.data_map
    return len(entries)
1166
1239
1240
def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
    """Generates a string representation of a tree in the forest.

    The tree rooted at ``td_root`` is fetched with :meth:`get_tree` and
    rendering is delegated to that tree's ``to_string`` method. The result
    lists nodes line-by-line: each line holds the path to a node followed by
    the output of ``node_format_fn`` for that node, which makes this method
    handy for debugging.

    Args:
        td_root (TensorDict): Root of the tree.

        node_format_fn (Callable, optional): User-defined function to
            generate a string for each node of the tree. The signature must
            be ``(Tree) -> Any``, and the output must be convertible to a
            string. If this argument is not given, the generated string is
            the node's :attr:`Tree.node_data` attribute converted to a dict.

    Returns:
        Whatever the tree's ``to_string`` method returns for
        ``node_format_fn``.

    Examples:
        >>> from torchrl.data import MCTSForest
        >>> from tensordict import TensorDict
        >>> forest = MCTSForest()
        >>> td_root = TensorDict({"observation": 0,})
        >>> rollouts_data = [
        ...     # [(action, obs), ...]
        ...     [(3, 123), (1, 456)],
        ...     [(2, 359), (2, 3094)],
        ...     [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
        ...     [(1, 75)],
        ...     [(3, 123), (0, 948)],
        ...     [(2, 359), (2, 3094), (10, 68)],
        ...     [(2, 359), (2, 3094), (11, 9045)],
        ... ]
        >>> for rollout_data in rollouts_data:
        ...     td = td_root.clone().unsqueeze(0)
        ...     for action, obs in rollout_data:
        ...         td = td.update(TensorDict({
        ...             "action": [action],
        ...             "next": TensorDict({"observation": [obs]}, [1]),
        ...         }, [1]))
        ...         forest.extend(td)
        ...         td = td["next"].clone()
        ...
        >>> print(forest.to_string(td_root))
        (0,) {'observation': tensor(123)}
         (0, 0) {'observation': tensor(456)}
         (0, 1) {'observation': tensor(847)}
         (0, 2) {'observation': tensor(948)}
        (1,) {'observation': tensor(3094)}
         (1, 0) {'observation': tensor(68)}
         (1, 1) {'observation': tensor(9045)}
        (2,) {'observation': tensor(75)}
    """
    # Look up the tree for this root, then let the tree render itself.
    return self.get_tree(td_root).to_string(node_format_fn)
1296
+
1167
1297
1168
1298
def _make_list_of_nestedkeys (obj : Any , attr : str ) -> List [NestedKey ]:
1169
1299
if obj is None :
0 commit comments