@@ -11,7 +11,7 @@
 from gymnasium import spaces

 from stable_baselines3.common.base_class import BaseAlgorithm
-from stable_baselines3.common.buffers import DictReplayBuffer, ReplayBuffer
+from stable_baselines3.common.buffers import DictReplayBuffer, NStepReplayBuffer, ReplayBuffer
 from stable_baselines3.common.callbacks import BaseCallback
 from stable_baselines3.common.noise import ActionNoise, VectorizedActionNoise
 from stable_baselines3.common.policies import BasePolicy
@@ -51,6 +51,7 @@ class OffPolicyAlgorithm(BaseAlgorithm):
     :param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
         at a cost of more complexity.
         See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
+    :param n_steps: When n_steps > 1, use n-step returns (with the NStepReplayBuffer) when updating the Q-value network.
     :param policy_kwargs: Additional arguments to be passed to the policy on creation
     :param stats_window_size: Window size for the rollout logging, specifying the number of episodes to average
         the reported success rate, mean episode length, and mean reward over
@@ -93,6 +94,7 @@ def __init__(
         replay_buffer_class: Optional[type[ReplayBuffer]] = None,
         replay_buffer_kwargs: Optional[dict[str, Any]] = None,
         optimize_memory_usage: bool = False,
+        n_steps: int = 1,
         policy_kwargs: Optional[dict[str, Any]] = None,
         stats_window_size: int = 100,
         tensorboard_log: Optional[str] = None,
@@ -134,7 +136,7 @@ def __init__(
         self.replay_buffer: Optional[ReplayBuffer] = None
         self.replay_buffer_class = replay_buffer_class
         self.replay_buffer_kwargs = replay_buffer_kwargs or {}
-        self._episode_storage = None
+        self.n_steps = n_steps

         # Save train freq parameter, will be converted later to TrainFreq object
         self.train_freq = train_freq
@@ -176,6 +178,11 @@ def _setup_model(self) -> None:
         if self.replay_buffer_class is None:
             if isinstance(self.observation_space, spaces.Dict):
                 self.replay_buffer_class = DictReplayBuffer
+                assert self.n_steps == 1, "N-step returns are not supported for Dict observation spaces yet."
+            elif self.n_steps > 1:
+                self.replay_buffer_class = NStepReplayBuffer
+                # Add required arguments for computing n-step returns
+                self.replay_buffer_kwargs.update({"n_steps": self.n_steps, "gamma": self.gamma})
             else:
                 self.replay_buffer_class = ReplayBuffer

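
The _setup_model hunk above only selects NStepReplayBuffer and forwards n_steps and gamma through replay_buffer_kwargs; the return itself is computed inside the buffer when transitions are sampled. As a rough standalone sketch (a hypothetical helper, not the buffer's actual code), the quantity handed to the Q-value update is the truncated n-step return together with the discount used for bootstrapping:

import numpy as np


def n_step_return(rewards: np.ndarray, dones: np.ndarray, gamma: float, n_steps: int) -> tuple[float, float]:
    """Illustrative helper (hypothetical, not part of stable-baselines3).

    Accumulates up to `n_steps` discounted rewards starting at index 0 and
    returns (n-step return, discount for the bootstrapped Q-value), stopping
    early if the episode terminates.
    """
    total, discount = 0.0, 1.0
    for step in range(n_steps):
        total += discount * float(rewards[step])
        discount *= gamma
        if dones[step]:  # do not accumulate rewards past the end of the episode
            break
    return total, discount


# 3-step example with gamma=0.99: 1.0 + 0.99 * 0.0 + 0.99**2 * 2.0 = 2.9602,
# with 0.99**3 left as the discount for the bootstrapped Q-value.
print(n_step_return(np.array([1.0, 0.0, 2.0]), np.array([False, False, False]), gamma=0.99, n_steps=3))

With n_steps=1 this reduces to the usual one-step target, which is why the default path keeps the plain ReplayBuffer.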
|
|
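From the user's side, n_steps is just a new constructor argument on off-policy algorithms. A minimal usage sketch, assuming this branch is installed and that DQN forwards the argument to OffPolicyAlgorithm (the environment id and hyperparameters below are arbitrary examples):

from stable_baselines3 import DQN

# n_steps=3 makes _setup_model pick NStepReplayBuffer and pass it
# n_steps and gamma, so sampled transitions carry 3-step returns.
model = DQN("MlpPolicy", "CartPole-v1", n_steps=3, learning_starts=1_000, verbose=1)
model.learn(total_timesteps=10_000)

Dict observation spaces still go through DictReplayBuffer, so the assert in _setup_model rejects n_steps > 1 in that case.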