@@ -1003,6 +1003,33 @@ def _make_node_map(self, source, dest):
1003
1003
self .max_size = self .data_map .max_size
1004
1004
1005
1005
def extend (self , rollout , * , return_node : bool = False ):
1006
+ """Add a rollout to the forest.
1007
+
1008
+ Nodes are only added to a tree at points where rollouts diverge from
1009
+ each other and at the endpoints of rollouts.
1010
+
1011
+ If there is no existing tree that matches the first steps of the
1012
+ rollout, a new tree is added. Only one node is created, for the final
1013
+ step.
1014
+
1015
+ If there is an existing tree that matches, the rollout is added to that
1016
+ tree. If the rollout diverges from all other rollouts in the tree at
1017
+ some step, a new node is created before the step where the rollouts
1018
+ diverge, and a leaf node is created for the final step of the rollout.
1019
+ If all of the rollout's steps match with a previously added rollout,
1020
+ nothing changes. If the rollout matches up to a leaf node of a tree but
1021
+ continues beyond it, that node is extended to the end of the rollout,
1022
+ and no new nodes are created.
1023
+
1024
+ Args:
1025
+ rollout (TensorDict): The rollout to add to the forest.
1026
+ return_node (bool, optional): If True, the method returns the added
1027
+ node. Default is ``False``.
1028
+
1029
+ Returns:
1030
+ Tree: The node that was added to the forest. This is only
1031
+ returned if ``return_node`` is True.
1032
+ """
1006
1033
source , dest = (
1007
1034
rollout .exclude ("next" ).copy (),
1008
1035
rollout .select ("next" , * self .action_keys ).copy (),
0 commit comments