diff --git a/requirements.txt b/requirements.txt index 85fa79507..34ddf949c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,5 +19,6 @@ huggingface_sb3>=3.0,<4.0 seaborn tqdm rich +envpool moviepy ruff diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index b61786f72..719aad91d 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -37,6 +37,7 @@ SubprocVecEnv, VecEnv, VecFrameStack, + VecMonitor, VecNormalize, VecTransposeImage, is_vecenv_wrapped, @@ -51,6 +52,15 @@ from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule +try: + import envpool +except ImportError: + envpool = None +else: + from envpool.python.protocol import EnvPool + + from rl_zoo3.vec_env_wrappers import EnvPoolAdapter + class ExperimentManager: """ @@ -98,6 +108,7 @@ def __init__( device: Union[th.device, str] = "auto", config: Optional[str] = None, show_progress: bool = False, + use_envpool: bool = False, ): super().__init__() self.algo = algo @@ -121,6 +132,7 @@ def __init__( self.seed = seed self.optimization_log_path = optimization_log_path + self.use_envpool = use_envpool self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type] self.vec_env_wrapper: Optional[Callable] = None @@ -598,38 +610,52 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) # Do not log eval env (issue with writing the same file) log_dir = None if eval_env or no_log else self.save_path - # Special case for GoalEnvs: log success rate too - if ( - "Neck" in self.env_name.gym_id - or self.is_robotics_env(self.env_name.gym_id) - or "parking-v0" in self.env_name.gym_id - and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs - ): - self.monitor_kwargs = dict(info_keywords=("is_success",)) - - spec = gym.spec(self.env_name.gym_id) - - # Define make_env here, so it works with subprocesses - # when 
the registry was modified with `--gym-packages` - # See https://github.com/HumanCompatibleAI/imitation/pull/160 - def make_env(**kwargs) -> gym.Env: - return spec.make(**kwargs) - - env_kwargs = self.eval_env_kwargs if eval_env else self.env_kwargs - - # On most env, SubprocVecEnv does not help and is quite memory hungry, - # therefore, we use DummyVecEnv by default - env = make_vec_env( - make_env, - n_envs=n_envs, - seed=self.seed, - env_kwargs=env_kwargs, - monitor_dir=log_dir, - wrapper_class=self.env_wrapper, - vec_env_cls=self.vec_env_class, # type: ignore[arg-type] - vec_env_kwargs=self.vec_env_kwargs, - monitor_kwargs=self.monitor_kwargs, - ) + if self.use_envpool: + if self.env_wrapper is not None: + warnings.warn("EnvPool does not support env wrappers, it will be ignored.") + if self.env_kwargs: + warnings.warn( + "EnvPool does not support env_kwargs, it will be ignored. " + "To pass keyword argument to envpool, use vec_env_kwargs instead." + ) + # Convert Atari game names + # See https://github.com/sail-sg/envpool/issues/14 + env_id = self.env_name.gym_id + if self._is_atari and "NoFrameskip-v4" in env_id: + env_id = env_id.split("NoFrameskip-v4")[0] + "-v5" + + env = envpool.make( + env_id, env_type="gymnasium", num_envs=n_envs, seed=self.seed, **self.vec_env_kwargs + ) # type: EnvPool + env.spec.id = self.env_name.gym_id + env = EnvPoolAdapter(env) + filename = None if log_dir is None else f"{log_dir}/monitor.csv" + env = VecMonitor(env, filename, **self.monitor_kwargs) + + else: + spec = gym.spec(self.env_name.gym_id) + + # Define make_env here, so it works with subprocesses + # when the registry was modified with `--gym-packages` + # See https://github.com/HumanCompatibleAI/imitation/pull/160 + def make_env(**kwargs) -> gym.Env: + return spec.make(**kwargs) + + env_kwargs = self.eval_env_kwargs if eval_env else self.env_kwargs + + # On most env, SubprocVecEnv does not help and is quite memory hungry, + # therefore, we use DummyVecEnv by default + env 
= make_vec_env( + make_env, + n_envs=n_envs, + seed=self.seed, + env_kwargs=env_kwargs, + monitor_dir=log_dir, + wrapper_class=self.env_wrapper, + vec_env_cls=self.vec_env_class, # type: ignore[arg-type] + vec_env_kwargs=self.vec_env_kwargs, + monitor_kwargs=self.monitor_kwargs, + ) if self.vec_env_wrapper is not None: env = self.vec_env_wrapper(env) diff --git a/rl_zoo3/train.py b/rl_zoo3/train.py index 53be1683b..44d2253ff 100644 --- a/rl_zoo3/train.py +++ b/rl_zoo3/train.py @@ -153,6 +153,13 @@ def train() -> None: default=False, help="if toggled, display a progress bar using tqdm and rich", ) + parser.add_argument( + "-envpool", + "--use-envpool", + action="store_true", + default=False, + help="if toggled, try to use EnvPool to run the env, env_wrappers are not supported.", + ) parser.add_argument( "-tags", "--wandb-tags", type=str, default=[], nargs="+", help="Tags for wandb run, e.g.: -tags optimized pr-123" ) @@ -178,7 +185,7 @@ def train() -> None: uuid_str = f"_{uuid.uuid4()}" if args.uuid else "" if args.seed < 0: # Seed but with a random one - args.seed = np.random.randint(2**32 - 1, dtype="int64").item() # type: ignore[attr-defined] + args.seed = np.random.randint(2**31 - 1, dtype="int64").item() # type: ignore[attr-defined] set_random_seed(args.seed) @@ -255,6 +262,7 @@ def train() -> None: device=args.device, config=args.conf_file, show_progress=args.progress, + use_envpool=args.use_envpool, ) # Prepare experiment and launch hyperparameter optimization if needed diff --git a/rl_zoo3/vec_env_wrappers.py b/rl_zoo3/vec_env_wrappers.py new file mode 100644 index 000000000..86bb9c6ce --- /dev/null +++ b/rl_zoo3/vec_env_wrappers.py @@ -0,0 +1,105 @@ +from typing import Any, Iterable, List, Optional, Sequence, Type, Union + +import gymnasium as gym +import numpy as np +from envpool.python.protocol import EnvPool +from gymnasium import spaces +from stable_baselines3.common.vec_env import VecEnv +from stable_baselines3.common.vec_env.base_vec_env import 
VecEnvObs, VecEnvStepReturn

# Used when we want to access one or more VecEnv
VecEnvIndices = Union[None, int, Iterable[int]]


def _convert_dtype_to_float32(space: spaces.Space) -> spaces.Space:
    """
    Convert the dtype of a space to float32.

    Only ``Box`` spaces whose dtype is float64 are changed; ``Dict`` spaces are
    processed recursively and every other space is returned untouched.
    The conversion is done in place: the space given as argument is modified
    and returned.

    :param space: Space to convert
    :return: Converted space
    """
    if isinstance(space, spaces.Box):
        if space.dtype == np.float64:
            space.dtype = np.dtype(np.float32)
    elif isinstance(space, spaces.Dict):
        # Re-assigning existing keys does not change the dict size,
        # so it is safe to do it while iterating.
        for key, sub_space in space.spaces.items():
            space.spaces[key] = _convert_dtype_to_float32(sub_space)
    return space


class EnvPoolAdapter(VecEnv):
    """
    Wrapper for EnvPool to make it compatible with Stable-Baselines3.

    :param venv: EnvPool environment
    """

    def __init__(self, venv: EnvPool) -> None:
        self.venv = venv
        action_space = venv.action_space
        observation_space = venv.observation_space
        # Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145
        observation_space = _convert_dtype_to_float32(observation_space)
        action_space = _convert_dtype_to_float32(action_space)

        super().__init__(
            num_envs=venv.spec.config.num_envs,  # Retrieve the number of environments from the config
            observation_space=observation_space,
            action_space=action_space,
        )

    def _store_reset_infos(self, reset_infos: dict, env_idx: int, source_idx: int) -> None:
        """
        Copy the reset infos of one environment into ``self.reset_infos``.

        :param reset_infos: Info dict returned by ``venv.reset()``,
            each value is indexed with ``source_idx``
        :param env_idx: Index of the environment in this VecEnv
        :param source_idx: Index of that environment inside the values of ``reset_infos``
        """
        for key, value in reset_infos.items():
            if key == "players":  # only used for multi-agent setting
                continue
            self.reset_infos[env_idx][key] = value[source_idx]

    def reset(self) -> VecEnvObs:
        """
        Reset all the environments and store the reset infos.

        :return: Observation returned by ``venv.reset()``
        """
        obs, reset_infos = self.venv.reset()
        for env_idx in range(self.num_envs):
            self._store_reset_infos(reset_infos, env_idx, env_idx)
        return obs

    def step_async(self, actions: np.ndarray) -> None:
        """
        Store the actions, they are sent to EnvPool in ``step_wait()``.

        :param actions: Actions to apply at the next step
        """
        self.actions = actions

    def step_wait(self) -> VecEnvStepReturn:
        """
        Step the environments with the stored actions.

        Environments that are done are reset immediately: the last observation
        of the episode is saved in ``infos[env_idx]["terminal_observation"]``
        and replaced in the returned observations by the first observation
        of the new episode.

        :return: observations, rewards, dones, infos
        """
        obs, rewards, terminated, truncated, info_dict = self.venv.step(self.actions)
        dones = terminated | truncated
        # Only array entries can be split per environment,
        # the list of such keys is the same for every env so compute it once
        array_keys = [key for key, value in info_dict.items() if isinstance(value, np.ndarray)]

        infos = []
        # Convert dict to list of dict and add terminal observation
        for env_idx in range(self.num_envs):
            info = {key: info_dict[key][env_idx] for key in array_keys}
            info["TimeLimit.truncated"] = truncated[env_idx] and not terminated[env_idx]
            if dones[env_idx]:
                # Copy is required: obs[env_idx] is a view on obs and the line below
                # overwrites it in place with the first observation of the next episode
                info["terminal_observation"] = obs[env_idx].copy()
                obs[env_idx], reset_infos = self.venv.reset(np.array([env_idx]))
                # Store reset_infos (a single env was reset, hence index 0)
                self._store_reset_infos(reset_infos, env_idx, 0)
            infos.append(info)
        return obs, rewards, dones, infos

    def close(self) -> None:
        """Do nothing."""
        return None  # No closing method in envpool

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """Not supported, always raise ``NotImplementedError``."""
        raise NotImplementedError("EnvPool does not support get_attr()")

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """Not supported, always raise ``NotImplementedError``."""
        raise NotImplementedError("EnvPool does not support set_attr()")

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """Not supported, always raise ``NotImplementedError``."""
        raise NotImplementedError("EnvPool does not support env_method()")

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """
        Report that no environment is wrapped, whatever the wrapper class.

        :param wrapper_class: Ignored
        :param indices: Ignored
        :return: A list of ``False``, one per environment
        """
        return [False for _ in range(self.num_envs)]

    def get_images(self) -> Sequence[np.ndarray]:
        """Not supported, always raise ``NotImplementedError``."""
        raise NotImplementedError("EnvPool does not support get_images()")

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Not supported, always raise ``NotImplementedError``."""
        raise NotImplementedError("EnvPool does not support render()")

    def seed(self, seed: Optional[int] = None) -> Sequence[None]:  # type: ignore[override]  # until SB3/#1318 is closed
        """
        Do nothing, the seed given as argument is ignored.

        :param seed: Ignored
        :return: A list of ``None``, one per environment
        """
        # You can only seed EnvPool env by calling envpool.make()
        return [None for _ in range(self.num_envs)]