-
Notifications
You must be signed in to change notification settings - Fork 526
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add EnvPool support (as VecEnv subclass) #355
base: master
Are you sure you want to change the base?
Changes from all commits
ad54b99
0171295
594ddca
36c30fd
d8bf984
35b32c0
6a94d26
1bca590
e457d3d
73d6bce
aec9529
a452cad
9dc8a35
ac532e8
6b0620f
0a992f1
c5f38e1
94f1199
16cec76
4265e9f
a63f849
3961b17
0946fc1
c94924e
a7bf243
6f161f8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,5 +19,6 @@ huggingface_sb3>=3.0,<4.0 | |
seaborn | ||
tqdm | ||
rich | ||
envpool | ||
moviepy | ||
ruff |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -153,6 +153,13 @@ def train() -> None: | |
default=False, | ||
help="if toggled, display a progress bar using tqdm and rich", | ||
) | ||
parser.add_argument( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing: a section in the readme, I will try to create a real doc soon (the readme is way too long now). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. doc is there |
||
"-envpool", | ||
"--use-envpool", | ||
action="store_true", | ||
default=False, | ||
help="if toggled, try to use EnvPool to run the env, env_wrappers are not supported.", | ||
) | ||
parser.add_argument( | ||
"-tags", "--wandb-tags", type=str, default=[], nargs="+", help="Tags for wandb run, e.g.: -tags optimized pr-123" | ||
) | ||
|
@@ -178,7 +185,7 @@ def train() -> None: | |
uuid_str = f"_{uuid.uuid4()}" if args.uuid else "" | ||
if args.seed < 0: | ||
# Seed but with a random one | ||
args.seed = np.random.randint(2**32 - 1, dtype="int64").item() # type: ignore[attr-defined] | ||
args.seed = np.random.randint(2**31 - 1, dtype="int64").item() # type: ignore[attr-defined] | ||
|
||
set_random_seed(args.seed) | ||
|
||
|
@@ -255,6 +262,7 @@ def train() -> None: | |
device=args.device, | ||
config=args.conf_file, | ||
show_progress=args.progress, | ||
use_envpool=args.use_envpool, | ||
) | ||
|
||
# Prepare experiment and launch hyperparameter optimization if needed | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
from typing import Any, Iterable, List, Optional, Sequence, Type, Union | ||
|
||
import gymnasium as gym | ||
import numpy as np | ||
from envpool.python.protocol import EnvPool | ||
from gymnasium import spaces | ||
from stable_baselines3.common.vec_env import VecEnv | ||
from stable_baselines3.common.vec_env.base_vec_env import VecEnvObs, VecEnvStepReturn | ||
|
||
# Used when we want to access one or more VecEnv | ||
VecEnvIndices = Union[None, int, Iterable[int]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we don't have that type hint defined in SB3? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, in |
||
|
||
|
||
def _convert_dtype_to_float32(space: spaces.Space) -> spaces.Space: | ||
""" | ||
Convert the dtype of a space to float32. | ||
|
||
:param space: Space to convert | ||
:return: Converted space | ||
""" | ||
if isinstance(space, spaces.Box): | ||
if space.dtype == np.float64: | ||
space.dtype = np.dtype(np.float32) | ||
elif isinstance(space, spaces.Dict): | ||
for key, sub_space in space.spaces.items(): | ||
space.spaces[key] = _convert_dtype_to_float32(sub_space) | ||
return space | ||
|
||
|
||
class EnvPoolAdapter(VecEnv): | ||
""" | ||
Wrapper for EnvPool to make it compatible with Stable-Baselines3. | ||
|
||
:param venv: EnvPool environment | ||
""" | ||
|
||
def __init__(self, venv: EnvPool) -> None: | ||
self.venv = venv | ||
action_space = venv.action_space | ||
observation_space = venv.observation_space | ||
# Tmp fix for https://github.com/DLR-RM/stable-baselines3/issues/1145 | ||
observation_space = _convert_dtype_to_float32(observation_space) | ||
action_space = _convert_dtype_to_float32(action_space) | ||
|
||
super().__init__( | ||
num_envs=venv.spec.config.num_envs, # Retrieve the number of environments from the config | ||
observation_space=observation_space, | ||
action_space=action_space, | ||
) | ||
|
||
def reset(self) -> VecEnvObs: | ||
obs, reset_infos = self.venv.reset() | ||
for key, value in reset_infos.items(): | ||
if key == "players": # only used for multi-agent setting | ||
continue | ||
for env_idx in range(self.num_envs): | ||
self.reset_infos[env_idx][key] = value[env_idx] | ||
return obs | ||
|
||
def step_async(self, actions: np.ndarray) -> None: | ||
self.actions = actions | ||
|
||
def step_wait(self) -> VecEnvStepReturn: | ||
obs, rewards, terminated, truncated, info_dict = self.venv.step(self.actions) | ||
dones = terminated | truncated | ||
|
||
infos = [] | ||
# Convert dict to list of dict and add terminal observation | ||
for env_idx in range(self.num_envs): | ||
infos.append({key: info_dict[key][env_idx] for key in info_dict.keys() if isinstance(info_dict[key], np.ndarray)}) | ||
infos[env_idx]["TimeLimit.truncated"] = truncated[env_idx] and not terminated[env_idx] | ||
if dones[env_idx]: | ||
infos[env_idx]["terminal_observation"] = obs[env_idx] | ||
obs[env_idx], reset_infos = self.venv.reset(np.array([env_idx])) | ||
# Store reset_infos | ||
for key, value in reset_infos.items(): | ||
if key == "players": | ||
continue | ||
self.reset_infos[env_idx][key] = value[0] | ||
return obs, rewards, dones, infos | ||
|
||
def close(self) -> None: | ||
return None # No closing method in envpool | ||
|
||
def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: | ||
raise NotImplementedError("EnvPool does not support get_attr()") | ||
|
||
def set_attr(self, attr_name: str, value: Any, indices: VecEnvIndices = None) -> None: | ||
raise NotImplementedError("EnvPool does not support set_attr()") | ||
|
||
def env_method(self, method_name: str, *method_args, indices: VecEnvIndices = None, **method_kwargs) -> List[Any]: | ||
raise NotImplementedError("EnvPool does not support env_method()") | ||
|
||
def env_is_wrapped(self, wrapper_class: Type[gym.Wrapper], indices: VecEnvIndices = None) -> List[bool]: | ||
return [False for _ in range(self.num_envs)] | ||
|
||
def get_images(self) -> Sequence[np.ndarray]: | ||
raise NotImplementedError("EnvPool does not support get_images()") | ||
|
||
def render(self, mode: str = "human") -> Optional[np.ndarray]: | ||
raise NotImplementedError("EnvPool does not support render()") | ||
|
||
def seed(self, seed: Optional[int] = None) -> Sequence[None]: # type: ignore[override] # until SB3/#1318 is closed | ||
# You can only seed EnvPool env by calling envpool.make() | ||
return [None for _ in range(self.num_envs)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we have
vec_env_kwargs
available in the cli, but we should add it.