diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index c04001c7c..598402569 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -90,7 +90,9 @@ SB3 VecEnv API is actually close to Gym 0.21 API but differs to Gym 0.26+ API: Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). - the ``reset()`` method doesn't take any parameter. If you want to seed the pseudo-random generator or pass options, - you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)`` and ``obs = vec_env.reset()`` afterward (seed and options are discarded after each call to ``reset()``). + you should call ``vec_env.seed(seed=seed)``/``vec_env.set_options(options)``. + The seed and options will be passed to the next call to ``obs = vec_env.reset()`` and to any implicit environment reset triggered by episode termination / truncation before that call. + The provided seed and options will be discarded after each call to ``vec_env.reset()``. - methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``, ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``. 
diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c1560201c..8cf4428a3 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,39 @@ Changelog ========== + +Release 2.4.0a1 (WIP) +-------------------------- + +Breaking Changes: +^^^^^^^^^^^^^^^^^ + +New Features: +^^^^^^^^^^^^^ + +Bug Fixes: +^^^^^^^^^^ +- Fixed seed / options argument passing to environment resets in ``vec_env.reset()`` + +`SB3-Contrib`_ +^^^^^^^^^^^^^^ + +`RL Zoo`_ +^^^^^^^^^ + +`SBX`_ (SB3 + Jax) +^^^^^^^^^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + +Documentation: +^^^^^^^^^^^^^^ +- Expanded the description for vec_env.reset seed and options passing + Release 2.3.0 (2024-03-31) -------------------------- diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py index 15ecfb681..0b5e19835 100644 --- a/stable_baselines3/common/vec_env/dummy_vec_env.py +++ b/stable_baselines3/common/vec_env/dummy_vec_env.py @@ -67,7 +67,10 @@ def step_wait(self) -> VecEnvStepReturn: if self.buf_dones[env_idx]: # save final observation where user can get it, then reset self.buf_infos[env_idx]["terminal_observation"] = obs - obs, self.reset_infos[env_idx] = self.envs[env_idx].reset() + # reset the environment, supplying seed and options + seed = self._seeds[env_idx] + options = self._options[env_idx] + obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=seed, options=options) self._save_obs(env_idx, obs) return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos)) diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index c598c735a..b0ef59f09 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -32,14 +32,15 @@ def _worker( try: cmd, data = remote.recv() if cmd == "step": - observation, reward, terminated, 
truncated, info = env.step(data) + action, seed, options = data + observation, reward, terminated, truncated, info = env.step(action) # convert to SB3 VecEnv api done = terminated or truncated info["TimeLimit.truncated"] = truncated and not terminated if done: # save final observation where user can get it, then reset info["terminal_observation"] = observation - observation, reset_info = env.reset() + observation, reset_info = env.reset(seed=seed, options=options) remote.send((observation, reward, done, info, reset_info)) elif cmd == "reset": maybe_options = {"options": data[1]} if data[1] else {} @@ -121,8 +122,9 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[ super().__init__(len(env_fns), observation_space, action_space) def step_async(self, actions: np.ndarray) -> None: - for remote, action in zip(self.remotes, actions): - remote.send(("step", action)) + for remote, action, seed, option in zip(self.remotes, actions, self._seeds, self._options): + # seed and option are used if step triggers a reset + remote.send(("step", (action, seed, option))) self.waiting = True def step_wait(self) -> VecEnvStepReturn: diff --git a/tests/test_logger.py b/tests/test_logger.py index dfd9e5567..8f5cf9ac5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -354,7 +354,7 @@ def __init__(self, delay: float = 0.01): self.observation_space = spaces.Box(low=-20.0, high=20.0, shape=(4,), dtype=np.float32) self.action_space = spaces.Discrete(2) - def reset(self, seed=None): + def reset(self, seed=None, options=None): return self.observation_space.sample(), {} def step(self, action): diff --git a/tests/test_predict.py b/tests/test_predict.py index 9a845232f..0ccd79a7c 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -30,7 +30,7 @@ def __init__(self): self.observation_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) - def reset(self, seed=None): 
+ def reset(self, seed=None, options=None): return self.observation_space.sample(), {} def step(self, action):