@@ -604,6 +604,42 @@ def plot(
604
604
f"Unknown plotting backend { backend } with figure { figure } ."
605
605
)
606
606
607
+ def to_string (self , node_format_fn ):
608
+ """Generates a string representation of the tree.
609
+
610
+ This function can pull out information from each of the nodes in a tree,
611
+ so it can be useful for debugging. The nodes are listed line-by-line.
612
+ Each line contains the path to the node, followed by the string
613
+ representation of that node generated with :arg:`node_format_fn`. Each
614
+ line is indented according to number of steps in the path required to
615
+ get to the corresponding node.
616
+
617
+ Args:
618
+ node_format_fn (Callable): User-defined function to generate a
619
+ string for each node of the tree. The signature must be
620
+ ``(Tree) -> Any``, and the output must be convertible to
621
+ a string.
622
+ """
623
+ queue = [
624
+ # tree, path
625
+ (self , ()),
626
+ ]
627
+
628
+ strings = []
629
+
630
+ while len (queue ) > 0 :
631
+ self , path = queue .pop ()
632
+ if self .subtree is not None :
633
+ for subtree_idx , subtree in reversed (list (enumerate (self .subtree ))):
634
+ queue .append ((subtree , path + (subtree_idx ,)))
635
+
636
+ if self .rollout is not None :
637
+ level = len (path )
638
+ string = node_format_fn (self )
639
+ strings .append (f"{ ' ' * (level - 1 )} { path } { string } " )
640
+
641
+ return "\n " .join (strings )
642
+
607
643
608
644
class MCTSForest :
609
645
"""A collection of MCTS trees.
@@ -1164,6 +1200,27 @@ def valid_paths(cls, tree: Tree):
1164
1200
def __len__ (self ):
1165
1201
return len (self .data_map )
1166
1202
1203
+ def to_string (self , td_root , node_format_fn ):
1204
+ """Generates a string representation of a tree in the forest.
1205
+
1206
+ This function can pull out information from each of the nodes in a tree,
1207
+ so it can be useful for debugging. The nodes are listed line-by-line.
1208
+ Each line contains the path to the node, followed by the string
1209
+ representation of that node generated with :arg:`node_format_fn`. Each
1210
+ line is indented according to number of steps in the path required to
1211
+ get to the corresponding node.
1212
+
1213
+ Args:
1214
+ td_root (TensorDict): Root of the tree.
1215
+
1216
+ node_format_fn (Callable): User-defined function to generate a
1217
+ string for each node of the tree. The signature must be
1218
+ ``(Tree) -> Any``, and the output must be convertible to
1219
+ a string.
1220
+ """
1221
+ tree = self .get_tree (td_root )
1222
+ return tree .to_string (node_format_fn )
1223
+
1167
1224
1168
1225
def _make_list_of_nestedkeys (obj : Any , attr : str ) -> List [NestedKey ]:
1169
1226
if obj is None :
0 commit comments