Skip to content

Commit 744ee17

Browse files
committed
[Doc] Add docstring for MCTSForest.extend
ghstack-source-id: dbef5e48ea55db6ba7867e1b24eb4711ad08af61 Pull Request resolved: #2795
1 parent 2865bbf commit 744ee17

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

torchrl/data/map/tree.py

+27
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,33 @@ def _make_node_map(self, source, dest):
10031003
self.max_size = self.data_map.max_size
10041004

10051005
def extend(self, rollout, *, return_node: bool = False):
1006+
"""Add a rollout to the forest.
1007+
1008+
Nodes are only added to a tree at points where rollouts diverge from
1009+
each other and at the endpoints of rollouts.
1010+
1011+
If there is no existing tree that matches the first steps of the
1012+
rollout, a new tree is added. Only one node is created, for the final
1013+
step.
1014+
1015+
If there is an existing tree that matches, the rollout is added to that
1016+
tree. If the rollout diverges from all other rollouts in the tree at
1017+
some step, a new node is created before the step where the rollouts
1018+
diverge, and a leaf node is created for the final step of the rollout.
1019+
If all of the rollout's steps match with a previously added rollout,
1020+
nothing changes. If the rollout matches up to a leaf node of a tree but
1021+
continues beyond it, that node is extended to the end of the rollout,
1022+
and no new nodes are created.
1023+
1024+
Args:
1025+
rollout (TensorDict): The rollout to add to the forest.
1026+
return_node (bool, optional): If True, the method returns the added
1027+
node. Default is ``False``.
1028+
1029+
Returns:
1030+
Tree: The node that was added to the forest. This is only
1031+
returned if ``return_node`` is True.
1032+
"""
10061033
source, dest = (
10071034
rollout.exclude("next").copy(),
10081035
rollout.select("next", *self.action_keys).copy(),

0 commit comments

Comments
 (0)