Skip to content

Commit 6063130

Browse files
committed
[Feature] Add MCTSForest/Tree.to_string
ghstack-source-id: ef146b5f851ce47495cb729c3b675170698856ff Pull Request resolved: pytorch#2794
1 parent 76aa9bc commit 6063130

File tree

2 files changed

+217
-0
lines changed

2 files changed

+217
-0
lines changed

test/test_storage_map.py

+87
Original file line numberDiff line numberDiff line change
@@ -684,6 +684,93 @@ def test_forest_check_obs_match(self, intersect):
684684
).all()
685685
prev_tree = subtree
686686

687+
def test_to_string(self):
688+
forest = MCTSForest()
689+
690+
td_root = TensorDict(
691+
{
692+
"observation": 0,
693+
}
694+
)
695+
696+
rollouts_data = [
697+
# [(action, obs), ...]
698+
[(3, 123), (1, 456)],
699+
[(2, 359), (2, 3094)],
700+
[(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
701+
[(1, 75)],
702+
[(3, 123), (0, 948)],
703+
[(2, 359), (2, 3094), (10, 68)],
704+
[(2, 359), (2, 3094), (11, 9045)],
705+
]
706+
707+
default_string_check = "\n".join(
708+
[
709+
"(0,) {'observation': tensor(123)}",
710+
" (0, 0) {'observation': tensor(456)}",
711+
" (0, 1) {'observation': tensor(847)}",
712+
" (0, 2) {'observation': tensor(948)}",
713+
"(1,) {'observation': tensor(3094)}",
714+
" (1, 0) {'observation': tensor(68)}",
715+
" (1, 1) {'observation': tensor(9045)}",
716+
"(2,) {'observation': tensor(75)}",
717+
]
718+
)
719+
720+
obs_string_check = "\n".join(
721+
[
722+
"(0,) [123]",
723+
" (0, 0) [456]",
724+
" (0, 1) [392, 989, 809, 847]",
725+
" (0, 2) [948]",
726+
"(1,) [359, 3094]",
727+
" (1, 0) [68]",
728+
" (1, 1) [9045]",
729+
"(2,) [75]",
730+
]
731+
)
732+
733+
action_string_check = "\n".join(
734+
[
735+
"(0,) [3]",
736+
" (0, 0) [1]",
737+
" (0, 1) [9, 6, 20, 21]",
738+
" (0, 2) [0]",
739+
"(1,) [2, 2]",
740+
" (1, 0) [10]",
741+
" (1, 1) [11]",
742+
"(2,) [1]",
743+
]
744+
)
745+
746+
for rollout_data in rollouts_data:
747+
td = td_root.clone().unsqueeze(0)
748+
for action, obs in rollout_data:
749+
td = td.update(
750+
TensorDict(
751+
{
752+
"action": [action],
753+
"next": TensorDict({"observation": [obs]}, [1]),
754+
},
755+
[1],
756+
)
757+
)
758+
forest.extend(td)
759+
td = td["next"].clone()
760+
761+
default_string = forest.to_string(td_root)
762+
assert default_string == default_string_check
763+
764+
obs_string = forest.to_string(
765+
td_root, lambda tree: tree.rollout["next", "observation"].tolist()
766+
)
767+
assert obs_string == obs_string_check
768+
769+
action_string = forest.to_string(
770+
td_root, lambda tree: tree.rollout["action"].tolist()
771+
)
772+
assert action_string == action_string_check
773+
687774

688775
if __name__ == "__main__":
689776
args, unknown = argparse.ArgumentParser().parse_known_args()

torchrl/data/map/tree.py

+130
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,79 @@ def plot(
604604
f"Unknown plotting backend {backend} with figure {figure}."
605605
)
606606

607+
def to_string(self, node_format_fn=lambda tree: tree.node_data.to_dict()):
608+
"""Generates a string representation of the tree.
609+
610+
This function can pull out information from each of the nodes in a tree,
611+
so it can be useful for debugging. The nodes are listed line-by-line.
612+
Each line contains the path to the node, followed by the string
613+
representation of that node generated with :arg:`node_format_fn`. Each
614+
line is indented according to number of steps in the path required to
615+
get to the corresponding node.
616+
617+
Args:
618+
node_format_fn (Callable, optional): User-defined function to
619+
generate a string for each node of the tree. The signature must
620+
be ``(Tree) -> Any``, and the output must be convertible to a
621+
string. If this argument is not given, the generated string is
622+
the node's :attr:`Tree.node_data` attribute converted to a dict.
623+
624+
Examples:
625+
>>> from torchrl.data import MCTSForest
626+
>>> from tensordict import TensorDict
627+
>>> forest = MCTSForest()
628+
>>> td_root = TensorDict({"observation": 0,})
629+
>>> rollouts_data = [
630+
... # [(action, obs), ...]
631+
... [(3, 123), (1, 456)],
632+
... [(2, 359), (2, 3094)],
633+
... [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
634+
... [(1, 75)],
635+
... [(3, 123), (0, 948)],
636+
... [(2, 359), (2, 3094), (10, 68)],
637+
... [(2, 359), (2, 3094), (11, 9045)],
638+
... ]
639+
>>> for rollout_data in rollouts_data:
640+
... td = td_root.clone().unsqueeze(0)
641+
... for action, obs in rollout_data:
642+
... td = td.update(TensorDict({
643+
... "action": [action],
644+
... "next": TensorDict({"observation": [obs]}, [1]),
645+
... }, [1]))
646+
... forest.extend(td)
647+
... td = td["next"].clone()
648+
...
649+
>>> tree = forest.get_tree(td_root)
650+
>>> print(tree.to_string())
651+
(0,) {'observation': tensor(123)}
652+
(0, 0) {'observation': tensor(456)}
653+
(0, 1) {'observation': tensor(847)}
654+
(0, 2) {'observation': tensor(948)}
655+
(1,) {'observation': tensor(3094)}
656+
(1, 0) {'observation': tensor(68)}
657+
(1, 1) {'observation': tensor(9045)}
658+
(2,) {'observation': tensor(75)}
659+
"""
660+
queue = [
661+
# tree, path
662+
(self, ()),
663+
]
664+
665+
strings = []
666+
667+
while len(queue) > 0:
668+
self, path = queue.pop()
669+
if self.subtree is not None:
670+
for subtree_idx, subtree in reversed(list(enumerate(self.subtree))):
671+
queue.append((subtree, path + (subtree_idx,)))
672+
673+
if self.rollout is not None:
674+
level = len(path)
675+
string = node_format_fn(self)
676+
strings.append(f"{' ' * (level - 1)}{path} {string}")
677+
678+
return "\n".join(strings)
679+
607680

608681
class MCTSForest:
609682
"""A collection of MCTS trees.
@@ -1164,6 +1237,63 @@ def valid_paths(cls, tree: Tree):
11641237
def __len__(self):
11651238
return len(self.data_map)
11661239

1240+
def to_string(self, td_root, node_format_fn=lambda tree: tree.node_data.to_dict()):
1241+
"""Generates a string representation of a tree in the forest.
1242+
1243+
This function can pull out information from each of the nodes in a tree,
1244+
so it can be useful for debugging. The nodes are listed line-by-line.
1245+
Each line contains the path to the node, followed by the string
1246+
representation of that node generated with :arg:`node_format_fn`. Each
1247+
line is indented according to number of steps in the path required to
1248+
get to the corresponding node.
1249+
1250+
Args:
1251+
td_root (TensorDict): Root of the tree.
1252+
1253+
node_format_fn (Callable, optional): User-defined function to
1254+
generate a string for each node of the tree. The signature must
1255+
be ``(Tree) -> Any``, and the output must be convertible to a
1256+
string. If this argument is not given, the generated string is
1257+
the node's :attr:`Tree.node_data` attribute converted to a dict.
1258+
1259+
Examples:
1260+
>>> from torchrl.data import MCTSForest
1261+
>>> from tensordict import TensorDict
1262+
>>> forest = MCTSForest()
1263+
>>> td_root = TensorDict({"observation": 0,})
1264+
>>> rollouts_data = [
1265+
... # [(action, obs), ...]
1266+
... [(3, 123), (1, 456)],
1267+
... [(2, 359), (2, 3094)],
1268+
... [(3, 123), (9, 392), (6, 989), (20, 809), (21, 847)],
1269+
... [(1, 75)],
1270+
... [(3, 123), (0, 948)],
1271+
... [(2, 359), (2, 3094), (10, 68)],
1272+
... [(2, 359), (2, 3094), (11, 9045)],
1273+
... ]
1274+
>>> for rollout_data in rollouts_data:
1275+
... td = td_root.clone().unsqueeze(0)
1276+
... for action, obs in rollout_data:
1277+
... td = td.update(TensorDict({
1278+
... "action": [action],
1279+
... "next": TensorDict({"observation": [obs]}, [1]),
1280+
... }, [1]))
1281+
... forest.extend(td)
1282+
... td = td["next"].clone()
1283+
...
1284+
>>> print(forest.to_string(td_root))
1285+
(0,) {'observation': tensor(123)}
1286+
(0, 0) {'observation': tensor(456)}
1287+
(0, 1) {'observation': tensor(847)}
1288+
(0, 2) {'observation': tensor(948)}
1289+
(1,) {'observation': tensor(3094)}
1290+
(1, 0) {'observation': tensor(68)}
1291+
(1, 1) {'observation': tensor(9045)}
1292+
(2,) {'observation': tensor(75)}
1293+
"""
1294+
tree = self.get_tree(td_root)
1295+
return tree.to_string(node_format_fn)
1296+
11671297

11681298
def _make_list_of_nestedkeys(obj: Any, attr: str) -> List[NestedKey]:
11691299
if obj is None:

0 commit comments

Comments
 (0)