From ad54b994a6af7dc0aa895e373e790bda23a30f06 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 28 Oct 2022 18:22:21 +0200 Subject: [PATCH 01/13] Add envpool support --- requirements.txt | 1 + rl_zoo3/exp_manager.py | 82 +++++++++++++++++++++++-------------- rl_zoo3/train.py | 9 +++- rl_zoo3/vec_env_wrappers.py | 42 +++++++++++++++++++ 4 files changed, 102 insertions(+), 32 deletions(-) create mode 100644 rl_zoo3/vec_env_wrappers.py diff --git a/requirements.txt b/requirements.txt index 9ef1cc5ca..288f6168f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,3 +18,4 @@ seaborn tqdm rich importlib-metadata~=4.13 # flake8 not compatible with importlib-metadata>5.0 +envpool diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index d27e876ee..f724843e7 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -35,6 +35,7 @@ SubprocVecEnv, VecEnv, VecFrameStack, + VecMonitor, VecNormalize, VecTransposeImage, is_vecenv_wrapped, @@ -49,6 +50,13 @@ from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule +try: + import envpool + + from rl_zoo3.vec_env_wrappers import EnvPoolAdapter +except ImportError: + envpool = None + class ExperimentManager: """ @@ -95,6 +103,7 @@ def __init__( device: Union[th.device, str] = "auto", yaml_file: Optional[str] = None, show_progress: bool = False, + use_envpool: bool = False, ): super().__init__() self.algo = algo @@ -118,6 +127,7 @@ def __init__( self.seed = seed self.optimization_log_path = optimization_log_path + self.use_envpool = use_envpool self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type] self.vec_env_wrapper = None @@ -559,37 +569,47 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) # Do not log eval env (issue with writing the same file) log_dir = None if eval_env or no_log else self.save_path - # Special case for GoalEnvs: log success rate too - if ( - "Neck" in self.env_name.gym_id - or self.is_robotics_env(self.env_name.gym_id) - or "parking-v0" in self.env_name.gym_id - and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs - ): - self.monitor_kwargs = dict(info_keywords=("is_success",)) - - # Define make_env here so it works with subprocesses - # when the registry was modified with `--gym-packages` - # See https://github.com/HumanCompatibleAI/imitation/pull/160 - spec = gym.spec(self.env_name.gym_id) - - def make_env(**kwargs) -> gym.Env: - env = spec.make(**kwargs) - return env - - # On most env, SubprocVecEnv does not help and is quite memory hungry - # therefore we use DummyVecEnv by default - env = make_vec_env( - make_env, - n_envs=n_envs, - seed=self.seed, - env_kwargs=self.env_kwargs, - monitor_dir=log_dir, - wrapper_class=self.env_wrapper, - vec_env_cls=self.vec_env_class, - vec_env_kwargs=self.vec_env_kwargs, - monitor_kwargs=self.monitor_kwargs, - ) + if self.use_envpool: + # TODO: warning if env wrapper is passed + # TODO: check that log10(self.seed) <= 9 + env = envpool.make(self.env_name.gym_id, env_type="gym", num_envs=n_envs, seed=self.seed) + env.spec.id = self.env_name.gym_id + env = EnvPoolAdapter(env) + filename = None if log_dir is None else f"{log_dir}/monitor.csv" + env = VecMonitor(env, filename, **self.monitor_kwargs) + + else: + # Special case for GoalEnvs: log success rate too + if ( + "Neck" in self.env_name.gym_id + or self.is_robotics_env(self.env_name.gym_id) + or "parking-v0" in self.env_name.gym_id + and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs + ): + self.monitor_kwargs = dict(info_keywords=("is_success",)) + + # Define make_env here so it works with subprocesses + # when the registry was modified with `--gym-packages` + # See https://github.com/HumanCompatibleAI/imitation/pull/160 + spec = gym.spec(self.env_name.gym_id) + + def make_env(**kwargs) -> gym.Env: + env = spec.make(**kwargs) + return env + + # On most env, SubprocVecEnv does not help and is quite memory hungry + # therefore we use DummyVecEnv by default + env = make_vec_env( + make_env, + n_envs=n_envs, + seed=self.seed, + env_kwargs=self.env_kwargs, + monitor_dir=log_dir, + wrapper_class=self.env_wrapper, + vec_env_cls=self.vec_env_class, + vec_env_kwargs=self.vec_env_kwargs, + monitor_kwargs=self.monitor_kwargs, + ) if self.vec_env_wrapper is not None: env = self.vec_env_wrapper(env) diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index bffd83663..03d499990 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -140,7 +140,13 @@ def train(): default=False, help="if toggled, display a progress bar using tqdm and rich", ) - + parser.add_argument( + "-envpool", + "--use-envpool", + action="store_true", + default=False, + help="if toggled, try to use EnvPool to run the env, env_wrappers are not supported.", + ) args = parser.parse_args() # Going through custom gym packages to let them register in the global registory @@ -236,6 +242,7 @@ def train(): device=args.device, yaml_file=args.yaml_file, show_progress=args.progress, + use_envpool=args.use_envpool, ) # Prepare experiment and launch hyperparameter optimization if needed diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py new file mode 100644 index 000000000..cc49360f1 --- /dev/null +++ b/rl_zoo3/vec_env_wrappers.py @@ -0,0 +1,42 @@ +from typing import Optional + +import numpy as np +from envpool.python.protocol import EnvPool +from stable_baselines3.common.vec_env import VecEnvWrapper +from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn + + +class EnvPoolAdapter(VecEnvWrapper): + """ + Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv. + + :param venv: The envpool object. + """ + + def __init__(self, venv: EnvPool): + # Retrieve the number of environments from the config + venv.num_envs = venv.spec.config.num_envs + super().__init__(venv=venv) + + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions + + def reset(self) -> VecEnvObs: + return self.venv.reset() + + def seed(self, seed: Optional[int] = None) -> None: + # You can only seed EnvPool env by calling envpool.make() + pass + + def step_wait(self) -> VecEnvStepReturn: + obs, rewards, dones, info_dict = self.venv.step(self.actions) + infos = [] + # Convert dict to list of dict + # and add terminal observation + for i in range(self.num_envs): + infos.append({key: info_dict[key][i] for key in info_dict.keys() if isinstance(info_dict[key], np.ndarray)}) + if dones[i]: + infos[i]["terminal_observation"] = obs[i] + obs[i] = self.venv.reset(np.array([i])) + + return obs, rewards, dones, infos From 01712955c796c063886299729c6fb7e1e432f3b6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 28 Oct 2022 20:59:38 +0200 Subject: [PATCH 02/13] Update max seed for envpool --- rl_zoo3/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index 03d499990..8549081c4 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -168,7 +168,8 @@ def train(): uuid_str = f"_{uuid.uuid4()}" if args.uuid else "" if args.seed < 0: # Seed but with a random one - args.seed = np.random.randint(2**32 - 1, dtype="int64").item() + max_number = 2**31 - 1 if args.use_envpool else 2**32 - 1 + args.seed = np.random.randint(max_number, dtype="int64").item() set_random_seed(args.seed) From 594ddca11c5694f390a01592a48b4227aa507e3f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Fri, 28 Oct 2022 21:16:26 +0200 Subject: [PATCH 03/13] Tmp fix for SAC with float64 actions --- rl_zoo3/vec_env_wrappers.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py index cc49360f1..cc16ddaef 100644 --- a/rl_zoo3/vec_env_wrappers.py +++ b/rl_zoo3/vec_env_wrappers.py @@ -1,5 +1,6 @@ from typing import Optional +import gym import numpy as np from envpool.python.protocol import EnvPool from stable_baselines3.common.vec_env import VecEnvWrapper @@ -17,6 +18,9 @@ def __init__(self, venv: EnvPool): # Retrieve the number of environments from the config venv.num_envs = venv.spec.config.num_envs super().__init__(venv=venv) + # Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145 + if isinstance(self.action_space, gym.spaces.Box) and self.action_space.dtype == np.float64: + self.action_space.dtype = np.dtype(np.float32) def step_async(self, actions: np.ndarray) -> None: self.actions = actions From 1bca590cf11a241c46b5a713119a7bd10592575d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Sat, 17 Dec 2022 13:18:54 +0100 Subject: [PATCH 04/13] Add support for envpool with Atari --- rl_zoo3/exp_manager.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index e3722db71..7250daba8 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -593,8 +593,13 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) if self.use_envpool: # TODO: warning if env wrapper is passed - # TODO: check that log10(self.seed) <= 9 - env = envpool.make(self.env_name.gym_id, env_type="gym", num_envs=n_envs, seed=self.seed) + # Convert Atari game names + # See https://github.com/sail-sg/envpool/issues/14 + env_id = self.env_name.gym_id + if self._is_atari and "NoFrameskip-v4" in env_id: + env_id = env_id.split("NoFrameskip-v4")[0] + "-v5" + + env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed) env.spec.id = self.env_name.gym_id env = EnvPoolAdapter(env) filename = None if log_dir is None else f"{log_dir}/monitor.csv" From a452cadd6f210b5a74cf965256133a9615dc7f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Mon, 6 Feb 2023 10:36:50 +0100 Subject: [PATCH 05/13] Try to fix CI --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ea395d44c..cc2ed9de8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,6 +31,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip + python -m pip install setuptools==66.1 # cpu version of pytorch - faster to download pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install pybullet==3.1.9 From 9dc8a35fd51cbb7e7985bae2b512751a4058f890 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Mon, 6 Feb 2023 11:12:16 +0100 Subject: [PATCH 06/13] Downgrade setuptools to 65.5.0 in CI --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index cc2ed9de8..f40f951d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -31,7 +31,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install setuptools==66.1 + python -m pip install setuptools==65.5.0 # cpu version of pytorch - faster to download pip install torch==1.11+cpu -f https://download.pytorch.org/whl/torch_stable.html pip install pybullet==3.1.9 From 6b0620fc33f63fe4290ddd88ee4a534268aadbfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Mon, 6 Feb 2023 20:06:34 +0100 Subject: [PATCH 07/13] Drop vecenvwrapper in favor of vecenv --- rl_zoo3/vec_env_wrappers.py | 81 +++++++++++++++++++++++++++---------- 1 file changed, 59 insertions(+), 22 deletions(-) diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py index cc16ddaef..341750032 100644 --- a/rl_zoo3/vec_env_wrappers.py +++ b/rl_zoo3/vec_env_wrappers.py @@ -1,46 +1,83 @@ -from typing import Optional +import inspect +from typing import Any, Iterable, List, Optional, Sequence, Type, Union import gym import numpy as np from envpool.python.protocol import EnvPool -from stable_baselines3.common.vec_env import VecEnvWrapper +from gym import spaces +from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn +# Used when we want to access one or more VecEnv +VecEnvIndices = Union[None, int, Iterable[int]] -class EnvPoolAdapter(VecEnvWrapper): + +class EnvPoolAdapter(VecEnv): """ - Convert EnvPool object to a Stable-Baselines3 (SB3) VecEnv. + Vectorized environment base class - :param venv: The envpool object. + :param venv: the vectorized environment to wrap + :param observation_space: the observation space (can be None to load from venv) + :param action_space: the action space (can be None to load from venv) """ - def __init__(self, venv: EnvPool): + def __init__(self, venv: EnvPool) -> None: # Retrieve the number of environments from the config - venv.num_envs = venv.spec.config.num_envs - super().__init__(venv=venv) + self.venv = venv + action_space = venv.action_space + observation_space = venv.observation_space # Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145 - if isinstance(self.action_space, gym.spaces.Box) and self.action_space.dtype == np.float64: - self.action_space.dtype = np.dtype(np.float32) + for space in [observation_space, action_space]: + if isinstance(space, spaces.Box) and space.dtype == np.float64: + space.dtype = np.dtype(np.float32) - def step_async(self, actions: np.ndarray) -> None: - self.actions = actions + super().__init__( + num_envs=venv.spec.config.num_envs, + observation_space=observation_space, + action_space=action_space, + ) + self.class_attributes = dict(inspect.getmembers(self.__class__)) # TODO: unused def reset(self) -> VecEnvObs: return self.venv.reset() - def seed(self, seed: Optional[int] = None) -> None: - # You can only seed EnvPool env by calling envpool.make() - pass + def step_async(self, actions: np.ndarray) -> None: + self.actions = actions def step_wait(self) -> VecEnvStepReturn: obs, rewards, dones, info_dict = self.venv.step(self.actions) infos = [] - # Convert dict to list of dict - # and add terminal observation - for i in range(self.num_envs): - infos.append({key: info_dict[key][i] for key in info_dict.keys() if isinstance(info_dict[key], np.ndarray)}) - if dones[i]: - infos[i]["terminal_observation"] = obs[i] - obs[i] = self.venv.reset(np.array([i])) + # Convert dict to list of dict and add terminal observation + for env_idx in range(self.num_envs): + infos.append({key: info_dict[key][env_idx] for key in info_dict.keys() if isinstance(info_dict[key], np.ndarray)}) + if dones[env_idx]: + infos[env_idx]["terminal_observation"] = obs[env_idx] + obs[env_idx] = self.venv.reset(np.array([env_idx])) return obs, rewards, dones, infos + + def close(self) -> None: + # No closing method in envpool + return None + + def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: + raise NotImplementedError() # TODO: Does envpool support this? + + def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: + raise NotImplementedError() # TODO: Does envpool support this? + + def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: + raise NotImplementedError() # TODO: Does envpool support this? + + def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: + return [False for _ in range(self.num_envs)] + + def get_images(self) -> Sequence[np.ndarray]: + raise NotImplementedError() # TODO: Does envpool support this? + + def render(self, mode: str = "human") -> Optional[np.ndarray]: + raise NotImplementedError() # TODO: Does envpool support this? + + def seed(self, seed: Optional[int] = None) -> Sequence[None]: # type: ignore[override] # until SB3/#1318 is closed + # You can only seed EnvPool env by calling envpool.make() + return [None for _ in range(self.num_envs)] From 0a992f1eec123c2d5d45f052ae3989e05c1e2563 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Thu, 9 Feb 2023 16:41:00 +0100 Subject: [PATCH 08/13] NotImplementedError for vecenv methods --- rl_zoo3/vec_env_wrappers.py | 44 ++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py index 341750032..28468a091 100644 --- a/rl_zoo3/vec_env_wrappers.py +++ b/rl_zoo3/vec_env_wrappers.py @@ -1,4 +1,3 @@ -import inspect from typing import Any, Iterable, List, Optional, Sequence, Type, Union import gym @@ -12,31 +11,41 @@ VecEnvIndices = Union[None, int, Iterable[int]] +def _convert_dtype_to_float32(space: spaces.Space) -> spaces.Space: + """ + Convert the dtype of a space to float32. + + :param space: Space to convert + :return: Converted space + """ + if isinstance(space, spaces.Box): + space.dtype = np.dtype(np.float32) + elif isinstance(space, spaces.Dict): + for key, sub_space in space.spaces.items(): + space.spaces[key] = _convert_dtype_to_float32(sub_space) + return space + + class EnvPoolAdapter(VecEnv): """ - Vectorized environment base class + Wrapper for EnvPool to make it compatible with Stable-Baselines3. - :param venv: the vectorized environment to wrap - :param observation_space: the observation space (can be None to load from venv) - :param action_space: the action space (can be None to load from venv) + :param venv: EnvPool environment """ def __init__(self, venv: EnvPool) -> None: - # Retrieve the number of environments from the config self.venv = venv action_space = venv.action_space observation_space = venv.observation_space # Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145 - for space in [observation_space, action_space]: - if isinstance(space, spaces.Box) and space.dtype == np.float64: - space.dtype = np.dtype(np.float32) + observation_space = _convert_dtype_to_float32(observation_space) + action_space = _convert_dtype_to_float32(action_space) super().__init__( - num_envs=venv.spec.config.num_envs, + num_envs=venv.spec.config.num_envs, # Retrieve the number of environments from the config observation_space=observation_space, action_space=action_space, ) - self.class_attributes = dict(inspect.getmembers(self.__class__)) # TODO: unused def reset(self) -> VecEnvObs: return self.venv.reset() @@ -57,26 +66,25 @@ def step_wait(self) -> VecEnvStepReturn: return obs, rewards, dones, infos def close(self) -> None: - # No closing method in envpool - return None + return None # No closing method in envpool def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: - raise NotImplementedError() # TODO: Does envpool support this? + raise NotImplementedError("EnvPool does not support get_attr()") def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: - raise NotImplementedError() # TODO: Does envpool support this? + raise NotImplementedError("EnvPool does not support set_attr()") def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: - raise NotImplementedError() # TODO: Does envpool support this? + raise NotImplementedError("EnvPool does not support env_method()") def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: return [False for _ in range(self.num_envs)] def get_images(self) -> Sequence[np.ndarray]: - raise NotImplementedError() # TODO: Does envpool support this? + raise NotImplementedError("EnvPool does not support get_images()") def render(self, mode: str = "human") -> Optional[np.ndarray]: - raise NotImplementedError() # TODO: Does envpool support this? + raise NotImplementedError("EnvPool does not support render()") def seed(self, seed: Optional[int] = None) -> Sequence[None]: # type: ignore[override] # until SB3/#1318 is closed # You can only seed EnvPool env by calling envpool.make() From c5f38e1e4d47a8e7007dc81902a5ba9c7db86b3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Thu, 9 Feb 2023 16:41:15 +0100 Subject: [PATCH 09/13] vec_env_kargs to envpool --- rl_zoo3/exp_manager.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index 1ab64a63d..e35bfacbd 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -54,10 +54,11 @@ try: import envpool - - from rl_zoo3.vec_env_wrappers import EnvPoolAdapter except ImportError: envpool = None +else: + from rl_zoo3.vec_env_wrappers import EnvPoolAdapter + from envpool.python.protocol import EnvPool class ExperimentManager: @@ -592,14 +593,21 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) log_dir = None if eval_env or no_log else self.save_path if self.use_envpool: - # TODO: warning if env wrapper is passed + if self.env_wrapper is not None: + warnings.warn("EnvPool does not support env wrappers, it will be ignored.") + if self.env_kwargs is not None: + warnings.warn( + "EnvPool does not support env_kwargs, it will be ignored. " + "To pass keyword argument to envpool, use vec_env_kwargs instead." + ) + # Convert Atari game names # See https://github.com/sail-sg/envpool/issues/14 env_id = self.env_name.gym_id if self._is_atari and "NoFrameskip-v4" in env_id: env_id = env_id.split("NoFrameskip-v4")[0] + "-v5" - env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed) + env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed, **self.vec_env_kwargs) # type: EnvPool env.spec.id = self.env_name.gym_id env = EnvPoolAdapter(env) filename = None if log_dir is None else f"{log_dir}/monitor.csv" From 94f119940b897c56fbb1672abe838e36a229d9a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Thu, 9 Feb 2023 16:41:40 +0100 Subject: [PATCH 10/13] isort --- rl_zoo3/exp_manager.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index e35bfacbd..800636418 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -57,9 +57,10 @@ except ImportError: envpool = None else: - from rl_zoo3.vec_env_wrappers import EnvPoolAdapter from envpool.python.protocol import EnvPool + from rl_zoo3.vec_env_wrappers import EnvPoolAdapter + class ExperimentManager: """ From 16cec76434dff33943366cc7d12c555105220571 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Thu, 9 Feb 2023 18:14:03 +0100 Subject: [PATCH 11/13] fix env_kwarg warning condition --- rl_zoo3/exp_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index 800636418..2e74636c1 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -596,7 +596,7 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) if self.use_envpool: if self.env_wrapper is not None: warnings.warn("EnvPool does not support env wrappers, it will be ignored.") - if self.env_kwargs is not None: + if self.env_kwargs: warnings.warn( "EnvPool does not support env_kwargs, it will be ignored. " "To pass keyword argument to envpool, use vec_env_kwargs instead." From 4265e9f22cc033e88d4afc1be16f03cd7ea260aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Thu, 9 Feb 2023 18:14:17 +0100 Subject: [PATCH 12/13] fix space type conveter --- rl_zoo3/vec_env_wrappers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py index 28468a091..f6034f056 100644 --- a/rl_zoo3/vec_env_wrappers.py +++ b/rl_zoo3/vec_env_wrappers.py @@ -19,7 +19,8 @@ def _convert_dtype_to_float32(space: spaces.Space) -> spaces.Space: :return: Converted space """ if isinstance(space, spaces.Box): - space.dtype = np.dtype(np.float32) + if space.dtype == np.float64: + space.dtype = np.dtype(np.float32) elif isinstance(space, spaces.Dict): for key, sub_space in space.spaces.items(): space.spaces[key] = _convert_dtype_to_float32(sub_space) From 6f161f8b7c245cb284fad684a3284d74fb53d492 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20GALLOU=C3=89DEC?= Date: Mon, 11 Dec 2023 09:44:44 +0100 Subject: [PATCH 13/13] update to gymnasium --- rl_zoo3/exp_manager.py | 4 +++- rl_zoo3/vec_env_wrappers.py | 25 +++++++++++++++++++------ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index 3552ba4a8..719aad91d 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -624,7 +624,9 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) if self._is_atari and "NoFrameskip-v4" in env_id: env_id = env_id.split("NoFrameskip-v4")[0] + "-v5" - env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed, **self.vec_env_kwargs) # type: EnvPool + env = envpool.make( + env_id, env_type="gymnasium", num_envs=n_envs, seed=self.seed, **self.vec_env_kwargs + ) # type: EnvPool env.spec.id = self.env_name.gym_id env = EnvPoolAdapter(env) filename = None if log_dir is None else f"{log_dir}/monitor.csv" diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py index f6034f056..86bb9c6ce 100644 --- a/rl_zoo3/vec_env_wrappers.py +++ b/rl_zoo3/vec_env_wrappers.py @@ -1,9 +1,9 @@ from typing import Any, Iterable, List, Optional, Sequence, Type, Union -import gym +import gymnasium as gym import numpy as np from envpool.python.protocol import EnvPool -from gym import spaces +from gymnasium import spaces from stable_baselines3.common.vec_env import VecEnv from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn @@ -49,21 +49,34 @@ def __init__(self, venv: EnvPool) -> None: ) def reset(self) -> VecEnvObs: - return self.venv.reset() + obs, reset_infos = self.venv.reset() + for key, value in reset_infos.items(): + if key == "players": # only used for multi-agent setting + continue + for env_idx in range(self.num_envs): + self.reset_infos[env_idx][key] = value[env_idx] + return obs def step_async(self, actions: np.ndarray) -> None: self.actions = actions def step_wait(self) -> VecEnvStepReturn: - obs, rewards, dones, info_dict = self.venv.step(self.actions) + obs, rewards, terminated, truncated, info_dict = self.venv.step(self.actions) + dones = terminated | truncated + infos = [] # Convert dict to list of dict and add terminal observation for env_idx in range(self.num_envs): infos.append({key: info_dict[key][env_idx] for key in info_dict.keys() if isinstance(info_dict[key], np.ndarray)}) + infos[env_idx]["TimeLimit.truncated"] = truncated[env_idx] and not terminated[env_idx] if dones[env_idx]: infos[env_idx]["terminal_observation"] = obs[env_idx] - obs[env_idx] = self.venv.reset(np.array([env_idx])) - + obs[env_idx], reset_infos = self.venv.reset(np.array([env_idx])) + # Store reset_infos + for key, value in reset_infos.items(): + if key == "players": + continue + self.reset_infos[env_idx][key] = value[0] return obs, rewards, dones, infos def close(self) -> None: