Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EnvPool support (as VecEnv subclass) #355

Open
wants to merge 26 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ad54b99
Add envpool support
araffin Oct 28, 2022
0171295
Update max seed for envpool
araffin Oct 28, 2022
594ddca
Tmp fix for SAC with float64 actions
araffin Oct 28, 2022
36c30fd
Merge branch 'master' into feat/envpool
araffin Nov 16, 2022
d8bf984
Merge branch 'master' into feat/envpool
araffin Nov 30, 2022
35b32c0
Merge branch 'master' into feat/envpool
araffin Dec 17, 2022
6a94d26
Merge branch 'master' into feat/envpool
araffin Dec 17, 2022
1bca590
Add support for envpool with Atari
araffin Dec 17, 2022
e457d3d
Merge branch 'master' into feat/envpool
araffin Jan 12, 2023
73d6bce
Merge branch 'master' into feat/envpool
qgallouedec Jan 25, 2023
aec9529
Merge branch 'master' into feat/envpool
araffin Feb 2, 2023
a452cad
Try to fix CI
qgallouedec Feb 6, 2023
9dc8a35
Downgrade setuptools to 65.5.0 in CI
qgallouedec Feb 6, 2023
ac532e8
Merge branch 'master' into feat/envpool
qgallouedec Feb 6, 2023
6b0620f
Drop vecenvwrapper in favor of vecenv
qgallouedec Feb 6, 2023
0a992f1
NotImplementedError for vecenv methods
qgallouedec Feb 9, 2023
c5f38e1
vec_env_kargs to envpool
qgallouedec Feb 9, 2023
94f1199
isort
qgallouedec Feb 9, 2023
16cec76
fix env_kwarg warning condition
qgallouedec Feb 9, 2023
4265e9f
fix space type converter
qgallouedec Feb 9, 2023
a63f849
Merge branch 'master' into feat/envpool_with_vecenv
qgallouedec Feb 9, 2023
3961b17
Merge branch 'master' into feat/envpool_with_vecenv
araffin Mar 2, 2023
0946fc1
Merge branch 'master' into feat/envpool_with_vecenv
araffin Mar 29, 2023
c94924e
Merge branch 'master' into feat/envpool_with_vecenv
araffin Nov 16, 2023
a7bf243
Merge branch 'master' into feat/envpool_with_vecenv
qgallouedec Dec 6, 2023
6f161f8
update to gymnasium
qgallouedec Dec 11, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ seaborn
tqdm
rich
importlib-metadata~=4.13 # flake8 not compatible with importlib-metadata>5.0
envpool
96 changes: 65 additions & 31 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
SubprocVecEnv,
VecEnv,
VecFrameStack,
VecMonitor,
VecNormalize,
VecTransposeImage,
is_vecenv_wrapped,
Expand All @@ -51,6 +52,15 @@
from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER
from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule

try:
import envpool
except ImportError:
envpool = None
else:
from envpool.python.protocol import EnvPool

from rl_zoo3.vec_env_wrappers import EnvPoolAdapter


class ExperimentManager:
"""
Expand Down Expand Up @@ -97,6 +107,7 @@ def __init__(
device: Union[th.device, str] = "auto",
config: Optional[str] = None,
show_progress: bool = False,
use_envpool: bool = False,
):
super().__init__()
self.algo = algo
Expand All @@ -120,6 +131,7 @@ def __init__(
self.seed = seed
self.optimization_log_path = optimization_log_path

self.use_envpool = use_envpool
self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type]
self.vec_env_wrapper = None

Expand Down Expand Up @@ -581,37 +593,59 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
# Do not log eval env (issue with writing the same file)
log_dir = None if eval_env or no_log else self.save_path

# Special case for GoalEnvs: log success rate too
if (
"Neck" in self.env_name.gym_id
or self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
):
self.monitor_kwargs = dict(info_keywords=("is_success",))

# Define make_env here so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
spec = gym.spec(self.env_name.gym_id)

def make_env(**kwargs) -> gym.Env:
env = spec.make(**kwargs)
return env

# On most env, SubprocVecEnv does not help and is quite memory hungry
# therefore we use DummyVecEnv by default
env = make_vec_env(
make_env,
n_envs=n_envs,
seed=self.seed,
env_kwargs=self.env_kwargs,
monitor_dir=log_dir,
wrapper_class=self.env_wrapper,
vec_env_cls=self.vec_env_class,
vec_env_kwargs=self.vec_env_kwargs,
monitor_kwargs=self.monitor_kwargs,
)
if self.use_envpool:
if self.env_wrapper is not None:
warnings.warn("EnvPool does not support env wrappers, it will be ignored.")
if self.env_kwargs:
warnings.warn(
"EnvPool does not support env_kwargs, it will be ignored. "
"To pass keyword argument to envpool, use vec_env_kwargs instead."
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we have vec_env_kwargs available in the cli, but we should add it.

)

# Convert Atari game names
# See https://github.com/sail-sg/envpool/issues/14
env_id = self.env_name.gym_id
if self._is_atari and "NoFrameskip-v4" in env_id:
env_id = env_id.split("NoFrameskip-v4")[0] + "-v5"

env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed, **self.vec_env_kwargs) # type: EnvPool
env.spec.id = self.env_name.gym_id
env = EnvPoolAdapter(env)
filename = None if log_dir is None else f"{log_dir}/monitor.csv"
env = VecMonitor(env, filename, **self.monitor_kwargs)

else:
# Special case for GoalEnvs: log success rate too
if (
"Neck" in self.env_name.gym_id
or self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (
"Neck" in self.env_name.gym_id
or self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
):
if (
self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
):

"Neck" in self.env_name.gym_id was for personal use and should be removed.

self.monitor_kwargs = dict(info_keywords=("is_success",))

# Define make_env here so it works with subprocesses
# when the registry was modified with `--gym-packages`
# See https://github.com/HumanCompatibleAI/imitation/pull/160
spec = gym.spec(self.env_name.gym_id)

def make_env(**kwargs) -> gym.Env:
env = spec.make(**kwargs)
return env

# On most env, SubprocVecEnv does not help and is quite memory hungry
# therefore we use DummyVecEnv by default
env = make_vec_env(
make_env,
n_envs=n_envs,
seed=self.seed,
env_kwargs=self.env_kwargs,
monitor_dir=log_dir,
wrapper_class=self.env_wrapper,
vec_env_cls=self.vec_env_class,
vec_env_kwargs=self.vec_env_kwargs,
monitor_kwargs=self.monitor_kwargs,
)

if self.vec_env_wrapper is not None:
env = self.vec_env_wrapper(env)
Expand Down
10 changes: 9 additions & 1 deletion rl_zoo3/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,13 @@ def train() -> None:
default=False,
help="if toggled, display a progress bar using tqdm and rich",
)
parser.add_argument(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing: a section in the readme, I will try to create a real doc soon (the readme is way too long now).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc is there

"-envpool",
"--use-envpool",
action="store_true",
default=False,
help="if toggled, try to use EnvPool to run the env, env_wrappers are not supported.",
)
parser.add_argument(
"-tags", "--wandb-tags", type=str, default=[], nargs="+", help="Tags for wandb run, e.g.: -tags optimized pr-123"
)
Expand Down Expand Up @@ -183,7 +190,7 @@ def train() -> None:
uuid_str = f"_{uuid.uuid4()}" if args.uuid else ""
if args.seed < 0:
# Seed but with a random one
args.seed = np.random.randint(2**32 - 1, dtype="int64").item() # type: ignore[attr-defined]
args.seed = np.random.randint(2**31 - 1, dtype="int64").item() # type: ignore[attr-defined]

set_random_seed(args.seed)

Expand Down Expand Up @@ -259,6 +266,7 @@ def train() -> None:
device=args.device,
config=args.conf_file,
show_progress=args.progress,
use_envpool=args.use_envpool,
)

# Prepare experiment and launch hyperparameter optimization if needed
Expand Down
92 changes: 92 additions & 0 deletions rl_zoo3/vec_env_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Any, Iterable, List, Optional, Sequence, Type, Union

import gym
import numpy as np
from envpool.python.protocol import EnvPool
from gym import spaces
from stable_baselines3.common.vec_env import VecEnv
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn

# Used when we want to access one or more VecEnv
VecEnvIndices = Union[None, int, Iterable[int]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we already have that type hint defined in SB3?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, in base_vec_env. I don't remember why I redefined it here; I probably should have imported it.



def _convert_dtype_to_float32(space: spaces.Space) -> spaces.Space:
    """
    Switch the declared dtype of float64 ``Box`` spaces to float32.

    The space is modified in place and the same object is returned.
    ``Dict`` spaces are traversed recursively; any other space type
    (and any ``Box`` whose dtype is not float64) is returned untouched.

    :param space: Space to convert
    :return: The same space object, with float64 ``Box`` dtypes set to float32
    """
    if isinstance(space, spaces.Dict):
        # Recurse into every entry and store the result back under the same key
        for name, child in space.spaces.items():
            space.spaces[name] = _convert_dtype_to_float32(child)
        return space

    if isinstance(space, spaces.Box) and space.dtype == np.float64:
        # Only the ``dtype`` attribute is replaced here
        space.dtype = np.dtype(np.float32)

    return space


class EnvPoolAdapter(VecEnv):
    """
    Wrapper for EnvPool to make it compatible with Stable-Baselines3.

    Stepping and resetting are forwarded to the wrapped EnvPool object.
    Attribute access, method calls on sub-environments and rendering
    raise ``NotImplementedError``; ``close()`` and ``seed()`` are no-ops.

    :param venv: EnvPool environment
    """

    def __init__(self, venv: EnvPool) -> None:
        self.venv = venv
        action_space = venv.action_space
        observation_space = venv.observation_space
        # Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145
        observation_space = _convert_dtype_to_float32(observation_space)
        action_space = _convert_dtype_to_float32(action_space)

        super().__init__(
            num_envs=venv.spec.config.num_envs,  # Retrieve the number of environments from the config
            observation_space=observation_space,
            action_space=action_space,
        )

    def reset(self) -> VecEnvObs:
        """
        Reset the wrapped EnvPool environment.

        :return: Whatever ``venv.reset()`` returns
        """
        return self.venv.reset()

    def step_async(self, actions: np.ndarray) -> None:
        """
        Store the actions; the actual step is performed in ``step_wait()``.

        :param actions: Actions to apply on the next ``step_wait()`` call
        """
        self.actions = actions

    def step_wait(self) -> VecEnvStepReturn:
        """
        Step the wrapped environment with the stored actions.

        The info dict returned by EnvPool is converted to a list with one
        dict per environment. Environments that are done are reset
        individually; their last observation is saved under the
        ``"terminal_observation"`` info key.

        :return: observations, rewards, dones and the list of info dicts
        """
        obs, rewards, dones, info_dict = self.venv.step(self.actions)
        # Only array-valued entries can be indexed per environment;
        # the set of such keys does not depend on the env index.
        array_keys = [key for key, value in info_dict.items() if isinstance(value, np.ndarray)]

        infos = []
        # Convert dict to list of dict and add terminal observation
        for env_idx in range(self.num_envs):
            info = {key: info_dict[key][env_idx] for key in array_keys}
            if dones[env_idx]:
                # Copy before the slot is overwritten below: when ``obs`` is a
                # NumPy array, ``obs[env_idx]`` can be a view on the same memory,
                # so without the copy the stored terminal observation would be
                # silently replaced by the reset observation.
                info["terminal_observation"] = np.copy(obs[env_idx])
                obs[env_idx] = self.venv.reset(np.array([env_idx]))
            infos.append(info)

        return obs, rewards, dones, infos

    def close(self) -> None:
        """Do nothing."""
        return None  # No closing method in envpool

    def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
        """
        Not supported.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError("EnvPool does not support get_attr()")

    def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None:
        """
        Not supported.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError("EnvPool does not support set_attr()")

    def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]:
        """
        Not supported.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError("EnvPool does not support env_method()")

    def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]:
        """
        Report that no environment is wrapped.

        :param wrapper_class: Ignored
        :param indices: Ignored
        :return: A list of ``False``, one entry per environment
        """
        return [False for _ in range(self.num_envs)]

    def get_images(self) -> Sequence[np.ndarray]:
        """
        Not supported.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError("EnvPool does not support get_images()")

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """
        Not supported.

        :raises NotImplementedError: Always
        """
        raise NotImplementedError("EnvPool does not support render()")

    def seed(self, seed: Optional[int] = None) -> Sequence[None]:  # type: ignore[override]  # until SB3/#1318 is closed
        """
        Do nothing with the given seed.

        :param seed: Ignored
        :return: A list of ``None``, one entry per environment
        """
        # You can only seed EnvPool env by calling envpool.make()
        return [None for _ in range(self.num_envs)]