Skip to content

Commit 26ca66e

Browse files
committed
[Feature] Add MCTSForest/Tree.to_string
ghstack-source-id: ef146b5f851ce47495cb729c3b675170698856ff Pull Request resolved: pytorch#2794
1 parent 76aa9bc commit 26ca66e

File tree

2 files changed

+128
-0
lines changed

2 files changed

+128
-0
lines changed

test/test_storage_map.py

+71
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,77 @@ def test_forest_check_obs_match(self, intersect):
684684
).all()
685685
prev_tree = subtree
686686

687+
def test_to_string(self):
    """Check ``MCTSForest.to_string`` against hand-written node listings."""
    forest = MCTSForest()

    td_root = TensorDict(
        {
            "observation": 0,
        }
    )

    rollouts_data = [
        # [(action, obs), ...]
        [(3, 123), (1, 456)],
        [(2, 359), (2, 3094)],
        [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
        [(1, 75)],
        [(3, 123), (0, 948)],
        [(2, 359), (2, 3094), (10, 68)],
        [(2, 359), (2, 3094), (11, 9045)],
    ]

    # Expected listings: one line per node, made of the node's path followed
    # by the values extracted from that node's rollout.
    obs_string_check = "\n".join(
        [
            "(0,) [123]",
            " (0, 0) [456]",
            " (0, 1) [392, 989, 809, 847]",
            " (0, 2) [948]",
            "(1,) [359, 3094]",
            " (1, 0) [68]",
            " (1, 1) [9045]",
            "(2,) [75]",
        ]
    )

    action_string_check = "\n".join(
        [
            "(0,) [3]",
            " (0, 0) [1]",
            " (0, 1) [9, 6, 20, 21]",
            " (0, 2) [0]",
            "(1,) [2, 2]",
            " (1, 0) [10]",
            " (1, 1) [11]",
            "(2,) [1]",
        ]
    )

    for rollout_data in rollouts_data:
        # Every rollout starts again from the shared root observation.
        td = td_root.clone().unsqueeze(0)
        for action, obs in rollout_data:
            td = td.update(
                TensorDict(
                    {
                        "action": [action],
                        "next": TensorDict({"observation": [obs]}, [1]),
                    },
                    [1],
                )
            )
            forest.extend(td)
            # The step's "next" entry becomes the starting point of the
            # following step.
            td = td["next"].clone()

    obs_string = forest.to_string(
        td_root, lambda tree: tree.rollout["next", "observation"].tolist()
    )
    assert obs_string == obs_string_check

    action_string = forest.to_string(
        td_root, lambda tree: tree.rollout["action"].tolist()
    )
    assert action_string == action_string_check
757+
687758

688759
if __name__ == "__main__":
689760
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/map/tree.py

+57
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,42 @@ def plot(
604604
f"Unknown plotting backend {backend} with figure {figure}."
605605
)
606606

607+
def to_string(self, node_format_fn) -> str:
    """Generates a string representation of the tree.

    This function can pull out information from each of the nodes in a tree,
    so it can be useful for debugging. The nodes are listed line-by-line.
    Each line contains the path to the node, followed by the string
    representation of that node generated with ``node_format_fn``. Each
    line is indented according to number of steps in the path required to
    get to the corresponding node.

    Nodes whose ``rollout`` is ``None`` are traversed but produce no line.

    Args:
        node_format_fn (Callable): User-defined function to generate a
            string for each node of the tree. The signature must be
            ``(Tree) -> Any``, and the output must be convertible to
            a string.

    Returns:
        str: One line per listed node, joined with newlines. The string is
        empty when no node has a rollout.
    """
    # Depth-first traversal driven by an explicit stack of (tree, path)
    # pairs; ``self`` is never rebound so it always refers to the root.
    stack = [(self, ())]
    lines = []

    while stack:
        node, path = stack.pop()

        if node.subtree is not None:
            # Push children in reverse so that they are popped, and hence
            # listed, in increasing index order.
            for subtree_idx, subtree in reversed(list(enumerate(node.subtree))):
                stack.append((subtree, path + (subtree_idx,)))

        if node.rollout is not None:
            # One space per step beyond the first. For an empty path the
            # count is negative, which yields no indentation at all.
            indent = " " * (len(path) - 1)
            lines.append(f"{indent}{path} {node_format_fn(node)}")

    return "\n".join(lines)
642+
607643

608644
class MCTSForest:
609645
"""A collection of MCTS trees.
@@ -1164,6 +1200,27 @@ def valid_paths(cls, tree: Tree):
11641200
def __len__(self) -> int:
    """Return the length of ``self.data_map``."""
    return len(self.data_map)
11661202

1203+
def to_string(self, td_root, node_format_fn) -> str:
    """Generates a string representation of a tree in the forest.

    This function can pull out information from each of the nodes in a tree,
    so it can be useful for debugging. The nodes are listed line-by-line.
    Each line contains the path to the node, followed by the string
    representation of that node generated with ``node_format_fn``. Each
    line is indented according to number of steps in the path required to
    get to the corresponding node.

    Args:
        td_root (TensorDict): Root of the tree.
        node_format_fn (Callable): User-defined function to generate a
            string for each node of the tree. The signature must be
            ``(Tree) -> Any``, and the output must be convertible to
            a string.

    Returns:
        str: The result of calling ``to_string(node_format_fn)`` on the
        tree obtained from ``self.get_tree(td_root)``.
    """
    # Look up the tree for this root, then delegate the formatting to it.
    tree = self.get_tree(td_root)
    return tree.to_string(node_format_fn)
1223+
11671224

11681225
def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]:
11691226
if obj is None:

0 commit comments

Comments
 (0)