You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardexpand all lines: torchrl/envs/transforms/rlhf.py
+46-3
Original file line number
Diff line number
Diff line change
@@ -4,6 +4,7 @@
4
4
# LICENSE file in the root directory of this source tree.
5
5
from __future__ importannotations
6
6
7
+
fromcollectionsimportdeque
7
8
fromcollections.abcimportMapping
8
9
fromcopyimportcopy, deepcopy
9
10
fromtypingimportAny, Callable, Iterable, Literal
@@ -87,11 +88,21 @@ class DataLoadingPrimer(TensorDictPrimer):
87
88
88
89
Args:
89
90
dataloader (Iterable[Any]): The dataloader to load data from.
91
+
92
+
Keyword Args:
90
93
primers (Composite | None, optional): The primers to use for each key in the dataloader. Defaults to None.
91
94
data_keys (List[NestedKey] | None, optional): The keys to use for each item in the dataloader. Defaults to None.
92
95
data_specs (List[TensorSpec] | None, optional): The specs to use for each item in the dataloader. Defaults to None.
93
96
example_data (Any, optional): Example data to use for initializing the primer. Defaults to None.
94
97
stack_method (Callable[[Any], Any] | Literal["as_nested_tensor", "as_padded_tensor"], optional): The method to use for stacking the data. Defaults to ``maybe_dense_stack``.
98
+
use_buffer (bool, optional): Whether to use a buffer to load the batches. When an environment has a batch-size
99
+
that differs from the dataloader's, or when partial resets are to be expected, using a buffer to store data
100
+
ensures that `next()` is called on the dataloader only when necessary, and that elements of the dataset
101
+
are loaded in order.
102
+
Defaults to ``True`` whenever the batch-size of the dataloader is greater than 1.
103
+
auto_batch_size (bool, optional): If ``True`` (default if `dataloader.batch_size > 0`), the batch size of the
104
+
tensordict returned by the transform will be automatically determined assuming that there is a single batch
105
+
dimension.
95
106
96
107
Attributes:
97
108
dataloader (Iterable[Any]): The dataloader to load data from.
@@ -339,14 +350,25 @@ class DataLoadingPrimer(TensorDictPrimer):
0 commit comments