-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add EnvPool support (as VecEnv subclass) #355
base: master
Are you sure you want to change the base?
Changes from 22 commits
ad54b99
0171295
594ddca
36c30fd
d8bf984
35b32c0
6a94d26
1bca590
e457d3d
73d6bce
aec9529
a452cad
9dc8a35
ac532e8
6b0620f
0a992f1
c5f38e1
94f1199
16cec76
4265e9f
a63f849
3961b17
0946fc1
c94924e
a7bf243
6f161f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,3 +18,4 @@ seaborn | |
tqdm | ||
rich | ||
importlib-metadata~=4.13 # flake8 not compatible with importlib-metadata>5.0 | ||
envpool |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -37,6 +37,7 @@ | |||||||||||||||||||||||
SubprocVecEnv, | ||||||||||||||||||||||||
VecEnv, | ||||||||||||||||||||||||
VecFrameStack, | ||||||||||||||||||||||||
VecMonitor, | ||||||||||||||||||||||||
VecNormalize, | ||||||||||||||||||||||||
VecTransposeImage, | ||||||||||||||||||||||||
is_vecenv_wrapped, | ||||||||||||||||||||||||
|
@@ -51,6 +52,15 @@ | |||||||||||||||||||||||
from rl_zoo3.hyperparams_opt import HYPERPARAMS_SAMPLER | ||||||||||||||||||||||||
from rl_zoo3.utils import ALGOS, get_callback_list, get_class_by_name, get_latest_run_id, get_wrapper_class, linear_schedule | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
try: | ||||||||||||||||||||||||
import envpool | ||||||||||||||||||||||||
except ImportError: | ||||||||||||||||||||||||
envpool = None | ||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
from envpool.python.protocol import EnvPool | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
from rl_zoo3.vec_env_wrappers import EnvPoolAdapter | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
||||||||||||||||||||||||
class ExperimentManager: | ||||||||||||||||||||||||
""" | ||||||||||||||||||||||||
|
@@ -97,6 +107,7 @@ def __init__( | |||||||||||||||||||||||
device: Union[th.device, str] = "auto", | ||||||||||||||||||||||||
config: Optional[str] = None, | ||||||||||||||||||||||||
show_progress: bool = False, | ||||||||||||||||||||||||
use_envpool: bool = False, | ||||||||||||||||||||||||
): | ||||||||||||||||||||||||
super().__init__() | ||||||||||||||||||||||||
self.algo = algo | ||||||||||||||||||||||||
|
@@ -120,6 +131,7 @@ def __init__( | |||||||||||||||||||||||
self.seed = seed | ||||||||||||||||||||||||
self.optimization_log_path = optimization_log_path | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
self.use_envpool = use_envpool | ||||||||||||||||||||||||
self.vec_env_class = {"dummy": DummyVecEnv, "subproc": SubprocVecEnv}[vec_env_type] | ||||||||||||||||||||||||
self.vec_env_wrapper = None | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
|
@@ -581,37 +593,59 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) | |||||||||||||||||||||||
# Do not log eval env (issue with writing the same file) | ||||||||||||||||||||||||
log_dir = None if eval_env or no_log else self.save_path | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Special case for GoalEnvs: log success rate too | ||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||
"Neck" in self.env_name.gym_id | ||||||||||||||||||||||||
or self.is_robotics_env(self.env_name.gym_id) | ||||||||||||||||||||||||
or "parking-v0" in self.env_name.gym_id | ||||||||||||||||||||||||
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs | ||||||||||||||||||||||||
): | ||||||||||||||||||||||||
self.monitor_kwargs = dict(info_keywords=("is_success",)) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Define make_env here so it works with subprocesses | ||||||||||||||||||||||||
# when the registry was modified with `--gym-packages` | ||||||||||||||||||||||||
# See https://github.com/HumanCompatibleAI/imitation/pull/160 | ||||||||||||||||||||||||
spec = gym.spec(self.env_name.gym_id) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def make_env(**kwargs) -> gym.Env: | ||||||||||||||||||||||||
env = spec.make(**kwargs) | ||||||||||||||||||||||||
return env | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# On most env, SubprocVecEnv does not help and is quite memory hungry | ||||||||||||||||||||||||
# therefore we use DummyVecEnv by default | ||||||||||||||||||||||||
env = make_vec_env( | ||||||||||||||||||||||||
make_env, | ||||||||||||||||||||||||
n_envs=n_envs, | ||||||||||||||||||||||||
seed=self.seed, | ||||||||||||||||||||||||
env_kwargs=self.env_kwargs, | ||||||||||||||||||||||||
monitor_dir=log_dir, | ||||||||||||||||||||||||
wrapper_class=self.env_wrapper, | ||||||||||||||||||||||||
vec_env_cls=self.vec_env_class, | ||||||||||||||||||||||||
vec_env_kwargs=self.vec_env_kwargs, | ||||||||||||||||||||||||
monitor_kwargs=self.monitor_kwargs, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
if self.use_envpool: | ||||||||||||||||||||||||
if self.env_wrapper is not None: | ||||||||||||||||||||||||
warnings.warn("EnvPool does not support env wrappers, it will be ignored.") | ||||||||||||||||||||||||
if self.env_kwargs: | ||||||||||||||||||||||||
warnings.warn( | ||||||||||||||||||||||||
"EnvPool does not support env_kwargs, it will be ignored. " | ||||||||||||||||||||||||
"To pass keyword argument to envpool, use vec_env_kwargs instead." | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Convert Atari game names | ||||||||||||||||||||||||
# See https://github.com/sail-sg/envpool/issues/14 | ||||||||||||||||||||||||
env_id = self.env_name.gym_id | ||||||||||||||||||||||||
if self._is_atari and "NoFrameskip-v4" in env_id: | ||||||||||||||||||||||||
env_id = env_id.split("NoFrameskip-v4")[0] + "-v5" | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
env = envpool.make(env_id, env_type="gym", num_envs=n_envs, seed=self.seed, **self.vec_env_kwargs) # type: EnvPool | ||||||||||||||||||||||||
env.spec.id = self.env_name.gym_id | ||||||||||||||||||||||||
env = EnvPoolAdapter(env) | ||||||||||||||||||||||||
filename = None if log_dir is None else f"{log_dir}/monitor.csv" | ||||||||||||||||||||||||
env = VecMonitor(env, filename, **self.monitor_kwargs) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
else: | ||||||||||||||||||||||||
# Special case for GoalEnvs: log success rate too | ||||||||||||||||||||||||
if ( | ||||||||||||||||||||||||
"Neck" in self.env_name.gym_id | ||||||||||||||||||||||||
or self.is_robotics_env(self.env_name.gym_id) | ||||||||||||||||||||||||
or "parking-v0" in self.env_name.gym_id | ||||||||||||||||||||||||
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs | ||||||||||||||||||||||||
): | ||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||
self.monitor_kwargs = dict(info_keywords=("is_success",)) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# Define make_env here so it works with subprocesses | ||||||||||||||||||||||||
# when the registry was modified with `--gym-packages` | ||||||||||||||||||||||||
# See https://github.com/HumanCompatibleAI/imitation/pull/160 | ||||||||||||||||||||||||
spec = gym.spec(self.env_name.gym_id) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
def make_env(**kwargs) -> gym.Env: | ||||||||||||||||||||||||
env = spec.make(**kwargs) | ||||||||||||||||||||||||
return env | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
# On most env, SubprocVecEnv does not help and is quite memory hungry | ||||||||||||||||||||||||
# therefore we use DummyVecEnv by default | ||||||||||||||||||||||||
env = make_vec_env( | ||||||||||||||||||||||||
make_env, | ||||||||||||||||||||||||
n_envs=n_envs, | ||||||||||||||||||||||||
seed=self.seed, | ||||||||||||||||||||||||
env_kwargs=self.env_kwargs, | ||||||||||||||||||||||||
monitor_dir=log_dir, | ||||||||||||||||||||||||
wrapper_class=self.env_wrapper, | ||||||||||||||||||||||||
vec_env_cls=self.vec_env_class, | ||||||||||||||||||||||||
vec_env_kwargs=self.vec_env_kwargs, | ||||||||||||||||||||||||
monitor_kwargs=self.monitor_kwargs, | ||||||||||||||||||||||||
) | ||||||||||||||||||||||||
|
||||||||||||||||||||||||
if self.vec_env_wrapper is not None: | ||||||||||||||||||||||||
env = self.vec_env_wrapper(env) | ||||||||||||||||||||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -153,6 +153,13 @@ def train() -> None: | |
default=False, | ||
help="if toggled, display a progress bar using tqdm and rich", | ||
) | ||
parser.add_argument( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing: a section in the readme, I will try to create a real doc soon (the readme is way too long now). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doc is there |
||
"-envpool", | ||
"--use-envpool", | ||
action="store_true", | ||
default=False, | ||
help="if toggled, try to use EnvPool to run the env, env_wrappers are not supported.", | ||
) | ||
parser.add_argument( | ||
"-tags", "--wandb-tags", type=str, default=[], nargs="+", help="Tags for wandb run, e.g.: -tags optimized pr-123" | ||
) | ||
|
@@ -183,7 +190,7 @@ def train() -> None: | |
uuid_str = f"_{uuid.uuid4()}" if args.uuid else "" | ||
if args.seed < 0: | ||
# Seed but with a random one | ||
args.seed = np.random.randint(2**32 - 1, dtype="int64").item() # type: ignore[attr-defined] | ||
args.seed = np.random.randint(2**31 - 1, dtype="int64").item() # type: ignore[attr-defined] | ||
|
||
set_random_seed(args.seed) | ||
|
||
|
@@ -259,6 +266,7 @@ def train() -> None: | |
device=args.device, | ||
config=args.conf_file, | ||
show_progress=args.progress, | ||
use_envpool=args.use_envpool, | ||
) | ||
|
||
# Prepare experiment and launch hyperparameter optimization if needed | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from typing import Any, Iterable, List, Optional, Sequence, Type, Union | ||
|
||
import gym | ||
import numpy as np | ||
from envpool.python.protocol import EnvPool | ||
from gym import spaces | ||
from stable_baselines3.common.vec_env import VecEnv | ||
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn | ||
|
||
# Used when we want to access one or more VecEnv | ||
VecEnvIndices = Union[None, int, Iterable[int]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't have that type hint defined in SB3? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, in |
||
|
||
|
||
def _convert_dtype_to_float32(space: spaces.Space) -> spaces.Space: | ||
""" | ||
Convert the dtype of a space to float32. | ||
|
||
:param space: Space to convert | ||
:return: Converted space | ||
""" | ||
if isinstance(space, spaces.Box): | ||
if space.dtype == np.float64: | ||
space.dtype = np.dtype(np.float32) | ||
elif isinstance(space, spaces.Dict): | ||
for key, sub_space in space.spaces.items(): | ||
space.spaces[key] = _convert_dtype_to_float32(sub_space) | ||
return space | ||
|
||
|
||
class EnvPoolAdapter(VecEnv): | ||
""" | ||
Wrapper for EnvPool to make it compatible with Stable-Baselines3. | ||
|
||
:param venv: EnvPool environment | ||
""" | ||
|
||
def __init__(self, venv: EnvPool) -> None: | ||
self.venv = venv | ||
action_space = venv.action_space | ||
observation_space = venv.observation_space | ||
# Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145 | ||
observation_space = _convert_dtype_to_float32(observation_space) | ||
action_space = _convert_dtype_to_float32(action_space) | ||
|
||
super().__init__( | ||
num_envs=venv.spec.config.num_envs, # Retrieve the number of environments from the config | ||
observation_space=observation_space, | ||
action_space=action_space, | ||
) | ||
|
||
def reset(self) -> VecEnvObs: | ||
return self.venv.reset() | ||
|
||
def step_async(self, actions: np.ndarray) -> None: | ||
self.actions = actions | ||
|
||
def step_wait(self) -> VecEnvStepReturn: | ||
obs, rewards, dones, info_dict = self.venv.step(self.actions) | ||
infos = [] | ||
# Convert dict to list of dict and add terminal observation | ||
for env_idx in range(self.num_envs): | ||
infos.append({key: info_dict[key][env_idx] for key in info_dict.keys() if isinstance(info_dict[key], np.ndarray)}) | ||
if dones[env_idx]: | ||
infos[env_idx]["terminal_observation"] = obs[env_idx] | ||
obs[env_idx] = self.venv.reset(np.array([env_idx])) | ||
|
||
return obs, rewards, dones, infos | ||
|
||
def close(self) -> None: | ||
return None # No closing method in envpool | ||
|
||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: | ||
raise NotImplementedError("EnvPool does not support get_attr()") | ||
|
||
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: | ||
raise NotImplementedError("EnvPool does not support set_attr()") | ||
|
||
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: | ||
raise NotImplementedError("EnvPool does not support env_method()") | ||
|
||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: | ||
return [False for _ in range(self.num_envs)] | ||
|
||
def get_images(self) -> Sequence[np.ndarray]: | ||
raise NotImplementedError("EnvPool does not support get_images()") | ||
|
||
def render(self, mode: str = "human") -> Optional[np.ndarray]: | ||
raise NotImplementedError("EnvPool does not support render()") | ||
|
||
def seed(self, seed: Optional[int] = None) -> Sequence[None]: # type: ignore[override] # until SB3/#1318 is closed | ||
# You can only seed EnvPool env by calling envpool.make() | ||
return [None for _ in range(self.num_envs)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we have
vec_env_kwargs
available in the cli, but we should add it.