
Commit fc09267 (parent cd362ea)
Fix discounts for early terminations and fix reward normalization

6 files changed: 40 additions, 25 deletions

stable_baselines3/common/buffers.py (7 additions, 4 deletions)

@@ -894,18 +894,18 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
         # Randomly choose env indices for each sample
         env_indices = np.random.randint(0, self.n_envs, size=batch_inds.shape)

-        # Compute n-step indices with wrap-around
-        steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
         # Note: the self.pos index is dangerous (will overlap two different episodes when buffer is full)
         # so we set self.pos-1 to truncated=True (temporarily) if done=False
         # TODO: avoid copying the whole array (requires some more indices trickery)
         safe_timeouts = self.timeouts.copy()
         safe_timeouts[self.pos - 1, :] = np.logical_not(self.dones[self.pos - 1, :])

+        # Compute n-step indices with wrap-around
+        steps = np.arange(self.n_steps).reshape(1, -1)  # shape: [1, n_steps]
         indices = (batch_inds[:, None] + steps) % self.buffer_size  # shape: [batch, n_steps]

         # Retrieve sequences of transitions
-        rewards_seq = self.rewards[indices, env_indices[:, None]]  # [batch, n_steps]
+        rewards_seq = self._normalize_reward(self.rewards[indices, env_indices[:, None]], env)  # [batch, n_steps]
         dones_seq = self.dones[indices, env_indices[:, None]]  # [batch, n_steps]
         truncs_seq = safe_timeouts[indices, env_indices[:, None]]  # [batch, n_steps]

@@ -917,6 +917,9 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
         done_idx = np.where(has_done_or_trunc, done_idx, self.n_steps - 1)

         mask = np.arange(self.n_steps).reshape(1, -1) <= done_idx[:, None]  # shape: [batch, n_steps]
+        # Compute discount factors for bootstrapping (using target Q-Value)
+        # It is gamma ** n_steps by default but should be adjusted in case of early termination/truncation.
+        target_q_discounts = self.gamma ** mask.sum(axis=1, keepdims=True).astype(np.float32)  # [batch, 1]

         # Apply discount
         discounts = self.gamma ** np.arange(self.n_steps, dtype=np.float32).reshape(1, -1)  # [1, n_steps]

@@ -939,6 +942,6 @@ def _get_samples(self, batch_inds: np.ndarray, env: Optional[VecNormalize] = None
             actions=self.to_torch(actions),
             next_observations=self.to_torch(next_obs),  # type: ignore[arg-type]
             dones=self.to_torch(final_dones),
-            # FIXME: what to do with self._normalize_reward ?
             rewards=self.to_torch(n_step_returns),
+            discounts=self.to_torch(target_q_discounts),
         )
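
To make the new bootstrap discount concrete, here is a minimal standalone sketch (plain NumPy with illustrative values, not the buffer code itself) of how the per-sample discount follows from the done/truncation mask, assuming gamma=0.99 and n_steps=3:

import numpy as np

gamma, n_steps = 0.99, 3
dones_seq = np.array([[0.0, 1.0, 0.0],   # episode terminates after the 2nd step
                      [0.0, 0.0, 0.0]])  # no termination within the window
truncs_seq = np.zeros_like(dones_seq)

# Index of the first done/truncation, or n_steps - 1 if there is none
has_done_or_trunc = (dones_seq + truncs_seq).max(axis=1) > 0
done_idx = np.argmax(dones_seq + truncs_seq, axis=1)
done_idx = np.where(has_done_or_trunc, done_idx, n_steps - 1)

# Transitions up to (and including) the first done/truncation are kept
mask = np.arange(n_steps).reshape(1, -1) <= done_idx[:, None]  # [batch, n_steps]
# Effective horizon per sample: gamma**2 for the first row, gamma**3 for the second
target_q_discounts = gamma ** mask.sum(axis=1, keepdims=True).astype(np.float32)
print(target_q_discounts)  # approx. [[0.9801], [0.9703]]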

stable_baselines3/common/type_aliases.py (3 additions, 0 deletions)

@@ -52,6 +52,8 @@ class ReplayBufferSamples(NamedTuple):
     next_observations: th.Tensor
     dones: th.Tensor
     rewards: th.Tensor
+    # For n-step replay buffer
+    discounts: Optional[th.Tensor] = None


 class DictReplayBufferSamples(NamedTuple):

@@ -60,6 +62,7 @@ class DictReplayBufferSamples(NamedTuple):
     next_observations: TensorDict
     dones: th.Tensor
     rewards: th.Tensor
+    discounts: Optional[th.Tensor] = None


 class RolloutReturn(NamedTuple):
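
Because the new discounts field defaults to None, buffers that do not compute n-step returns keep working unchanged; a small sketch of that backward compatibility (constructing the NamedTuple directly with dummy tensors, once this commit is applied):

import torch as th
from stable_baselines3.common.type_aliases import ReplayBufferSamples

dummy = th.zeros(2, 1)
# No `discounts` argument: the field defaults to None and consumers
# fall back to their scalar self.gamma (see the algorithm changes below).
samples = ReplayBufferSamples(
    observations=dummy,
    actions=dummy,
    next_observations=dummy,
    dones=dummy,
    rewards=dummy,
)
assert samples.discounts is None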

stable_baselines3/dqn/dqn.py (3 additions, 1 deletion)

@@ -191,6 +191,8 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
         for _ in range(gradient_steps):
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
+            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

             with th.no_grad():
                 # Compute the next Q-values using the target network

@@ -200,7 +202,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
                 # Avoid potential broadcast issue
                 next_q_values = next_q_values.reshape(-1, 1)
                 # 1-step TD target
-                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values

             # Get current Q-values estimates
             current_q_values = self.q_net(replay_data.observations)
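
The same one-line fallback is used in SAC and TD3 below. Whether discounts ends up being a Python float or a [batch, 1] tensor, the TD target broadcasts the same way; a quick shape-check sketch (hypothetical values, PyTorch only):

import torch as th

batch = 4
rewards = th.rand(batch, 1)       # n-step returns from the buffer
dones = th.zeros(batch, 1)
next_q_values = th.rand(batch, 1)

# Scalar gamma (1-step buffer) vs per-sample discounts (n-step buffer)
for discounts in (0.99, th.full((batch, 1), 0.99**3)):
    target_q_values = rewards + (1 - dones) * discounts * next_q_values
    assert target_q_values.shape == (batch, 1)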

stable_baselines3/sac/sac.py (3 additions, 1 deletion)

@@ -213,6 +213,8 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
         for gradient_step in range(gradient_steps):
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
+            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

             # We need to sample because `log_std` may have changed between two gradient steps
             if self.use_sde:

@@ -252,7 +254,7 @@ def train(self, gradient_steps: int, batch_size: int = 64) -> None:
                 # add entropy term
                 next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                 # td error + entropy term
-                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values

             # Get current Q-values estimates for each critic network
             # using action from the replay buffer

stable_baselines3/td3/td3.py (3 additions, 1 deletion)

@@ -163,6 +163,8 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
             self._n_updates += 1
             # Sample replay buffer
             replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)  # type: ignore[union-attr]
+            # For n-step replay, discount factor is gamma**n_steps (when no early termination)
+            discounts = replay_data.discounts if replay_data.discounts is not None else self.gamma

             with th.no_grad():
                 # Select action according to policy and add clipped noise

@@ -173,7 +175,7 @@ def train(self, gradient_steps: int, batch_size: int = 100) -> None:
                 # Compute the next Q-values: min over all critics targets
                 next_q_values = th.cat(self.critic_target(replay_data.next_observations, next_actions), dim=1)
                 next_q_values, _ = th.min(next_q_values, dim=1, keepdim=True)
-                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values
+                target_q_values = replay_data.rewards + (1 - replay_data.dones) * discounts * next_q_values

             # Get current Q-values estimates for each critic network
             current_q_values = self.critic(replay_data.observations, replay_data.actions)

tests/test_n_step_replay.py (21 additions, 18 deletions)

@@ -12,10 +12,8 @@ def test_run(model_class):
    env_id = "CartPole-v1" if model_class == DQN else "Pendulum-v1"
    env = make_vec_env(env_id, n_envs=2)

-    # FIXME: need to set the discount factor manually
    n_steps = 2
    gamma = 0.99
-    discount = gamma**n_steps

    model = model_class(
        "MlpPolicy",

@@ -29,7 +27,7 @@ def test_run(model_class):
        policy_kwargs=dict(net_arch=[64]),
        learning_starts=100,
        buffer_size=int(2e4),
-        gamma=discount,
+        gamma=gamma,
    )

    model.learn(total_timesteps=150)

@@ -103,11 +101,11 @@ def test_nstep_early_termination(done_at, n_steps):

    base_idx = 0
    batch = buffer._get_samples(np.array([base_idx]))
-    actual = batch.rewards.numpy().item()
+    actual = batch.rewards.item()

    expected = compute_expected_nstep_reward(gamma=0.99, n_steps=n_steps, stop_idx=done_at - base_idx)
    np.testing.assert_allclose(actual, expected, rtol=1e-4)
-    assert batch.dones.numpy().item() == 1.0
+    assert batch.dones.item() == 1.0


@pytest.mark.parametrize("truncated_at", [1, 2])

@@ -117,46 +115,51 @@ def test_nstep_early_truncation(truncated_at):

    base_idx = 0
    batch = buffer._get_samples(np.array([base_idx]))
-    actual = batch.rewards.numpy().item()
+    actual = batch.rewards.item()

    expected = compute_expected_nstep_reward(gamma=0.99, n_steps=3, stop_idx=truncated_at - base_idx)
    np.testing.assert_allclose(actual, expected, rtol=1e-4)
-    assert batch.dones.numpy().item() == 0.0
+    assert batch.dones.item() == 0.0


@pytest.mark.parametrize("n_steps", [3, 5])
-def test_nstep_no_termination_or_truncation(n_steps):
+def test_nstep_no_terminations(n_steps):
    buffer = create_buffer(n_steps=n_steps)
    fill_buffer(buffer, length=10)  # no done or truncation
+    gamma = 0.99

    base_idx = 3
    batch = buffer._get_samples(np.array([base_idx]))
-    actual = batch.rewards.numpy().item()
-
-    expected = compute_expected_nstep_reward(gamma=0.99, n_steps=n_steps)
+    actual = batch.rewards.item()
+    # Discount factor for bootstrapping with target Q-Value
+    np.testing.assert_allclose(batch.discounts.item(), gamma**n_steps)
+    expected = compute_expected_nstep_reward(gamma=gamma, n_steps=n_steps)
    np.testing.assert_allclose(actual, expected, rtol=1e-4)
-    assert batch.dones.numpy().item() == 0.0
+    assert batch.dones.item() == 0.0

    # Check that self.pos-1 truncation is set when buffer is full
    # Note: buffer size is 10, here we are erasing past transitions
    fill_buffer(buffer, length=2)
    # We create a tmp truncation to not sample across episodes
    base_idx = 0
    batch = buffer._get_samples(np.array([base_idx]))
-    actual = batch.rewards.numpy().item()
+    actual = batch.rewards.item()
    # Note: compute_expected_nstep assumes base_idx=1
    expected = compute_expected_nstep_reward(gamma=0.99, n_steps=n_steps, stop_idx=buffer.pos - 1)
    np.testing.assert_allclose(actual, expected, rtol=1e-4)
-    assert batch.dones.numpy().item() == 0.0
+    assert batch.dones.item() == 0.0
+    # Discount factor for bootstrapping with target Q-Value
+    # (bigger than gamma ** n_steps because of truncation at n_steps=2)
+    np.testing.assert_allclose(batch.discounts.item(), gamma**2)

    # Set done=1 manually, the tmp truncation should not be set (it would set batch.done=False)
    buffer.dones[buffer.pos - 1, :] = True
    batch = buffer._get_samples(np.array([base_idx]))
-    actual = batch.rewards.numpy().item()
+    actual = batch.rewards.item()
    # Note: compute_expected_nstep assumes base_idx=0
    expected = compute_expected_nstep_reward(gamma=0.99, n_steps=n_steps, stop_idx=buffer.pos - 1)
    np.testing.assert_allclose(actual, expected, rtol=1e-4)
-    assert batch.dones.numpy().item() == 1.0
+    assert batch.dones.item() == 1.0


def test_match_normal_buffer():

@@ -168,12 +171,12 @@ def test_match_normal_buffer():

    base_idx = 3
    batch1 = buffer._get_samples(np.array([base_idx]))
-    actual1 = batch1.rewards.numpy().item()
+    actual1 = batch1.rewards.item()

    batch2 = ref_buffer._get_samples(np.array([base_idx]))

    expected = compute_expected_nstep_reward(gamma=0.99, n_steps=1)
    np.testing.assert_allclose(actual1, expected, rtol=1e-4)
-    assert batch1.dones.numpy().item() == 0.0
+    assert batch1.dones.item() == 0.0

    np.testing.assert_allclose(batch1.rewards.numpy(), batch2.rewards.numpy(), rtol=1e-4)
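
For reference, the quantity these assertions check is a discounted sum of rewards truncated at the first termination. A hypothetical stand-in for the test helper (the real compute_expected_nstep_reward lives in the test module and depends on the rewards that fill_buffer writes) could look like:

import numpy as np

def nstep_return(rewards: np.ndarray, gamma: float, stop_idx: int) -> float:
    # Discounted sum of rewards up to and including stop_idx
    k = np.arange(stop_idx + 1)
    return float(np.sum(gamma**k * rewards[: stop_idx + 1]))

# With constant rewards of 1.0, gamma=0.99 and a stop after two steps:
print(nstep_return(np.ones(5), gamma=0.99, stop_idx=1))  # 1.99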
