Skip to content

Commit c5dee1a

Browse files
committed
Fixes for Numpy v2
1 parent b785624 commit c5dee1a

File tree

3 files changed

+9
-3
lines changed

3 files changed

+9
-3
lines changed

docs/misc/changelog.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ Breaking Changes:
1313

1414
New Features:
1515
^^^^^^^^^^^^^
16-
- Added support for NumPy v2.0 (via Torch)
16+
- Added support for NumPy v2.0: ``VecNormalize`` now casts normalized rewards to float32,
17+
updated bit flipping env to avoid overflow issues too
1718
- Added official support for Python 3.12
1819

1920
Bug Fixes:

stable_baselines3/common/envs/bit_flipping_env.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,17 @@ def convert_if_needed(self, state: np.ndarray) -> Union[int, np.ndarray]:
7575
:param state:
7676
:return:
7777
"""
78+
7879
if self.discrete_obs_space:
80+
# Convert from int8 to int32 for NumPy 2.0
81+
state = state.astype(np.int32)
7982
# The internal state is the binary representation of the
8083
# observed one
8184
return int(sum(state[i] * 2**i for i in range(len(state))))
8285

8386
if self.image_obs_space:
8487
size = np.prod(self.image_shape)
85-
image = np.concatenate((state * 255, np.zeros(size - len(state), dtype=np.uint8)))
88+
image = np.concatenate((state.astype(np.uint8) * 255, np.zeros(size - len(state), dtype=np.uint8)))
8689
return image.reshape(self.image_shape).astype(np.uint8)
8790
return state
8891

stable_baselines3/common/vec_env/vec_normalize.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,9 @@ def normalize_reward(self, reward: np.ndarray) -> np.ndarray:
254254
"""
255255
if self.norm_reward:
256256
reward = np.clip(reward / np.sqrt(self.ret_rms.var + self.epsilon), -self.clip_reward, self.clip_reward)
257-
return reward
257+
# Note: we cast to float32 as it corresponds to Python default float type
258+
# This cast is needed because `RunningMeanStd` keeps stats in float64
259+
return reward.astype(np.float32)
258260

259261
def unnormalize_obs(self, obs: Union[np.ndarray, dict[str, np.ndarray]]) -> Union[np.ndarray, dict[str, np.ndarray]]:
260262
# Avoid modifying by reference the original object

0 commit comments

Comments
 (0)