Skip to content

Commit

Permalink
Add support for python object in python config for wrapper/callbacks (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
araffin authored Dec 6, 2024
1 parent 633954f commit 506bb7a
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 16 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
## Release 2.5.0a0 (WIP)
## Release 2.5.0a1 (WIP)

### Breaking Changes
- Upgraded to PyTorch >= 2.3.0
- Upgraded to SB3 >= 2.5.0

### New Features
- Added support for NumPy v2
- Added support for specifying callbacks and env wrappers as Python objects in Python config files (instead of strings)

### Bug fixes

Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/enjoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def enjoy() -> None: # noqa: C901
obs = env.reset()

# Deterministic by default except for atari games
stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic
stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic)
deterministic = not stochastic

episode_reward = 0.0
Expand Down
3 changes: 1 addition & 2 deletions rl_zoo3/exp_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,8 +602,7 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
if (
"Neck" in self.env_name.gym_id
or self.is_robotics_env(self.env_name.gym_id)
or "parking-v0" in self.env_name.gym_id
and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs
or ("parking-v0" in self.env_name.gym_id and len(self.monitor_kwargs) == 0) # do not overwrite custom kwargs
):
self.monitor_kwargs = dict(info_keywords=("is_success",))

Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/push_to_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def package_to_hub(
model = ALGOS[algo].load(model_path, env=eval_env, custom_objects=custom_objects, device=args.device, **kwargs)

# Deterministic by default except for atari games
stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic
stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic)
deterministic = not stochastic

# Default model name, the model will be saved under "{algo}-{env_name}.zip"
Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/record_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@
model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs)

# Deterministic by default except for atari games
stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic
stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic)
deterministic = not stochastic

if video_folder is None:
Expand Down
26 changes: 18 additions & 8 deletions rl_zoo3/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,16 +99,22 @@ def get_class_name(wrapper_name):
kwargs = wrapper_dict[wrapper_name]
else:
kwargs = {}
wrapper_module = importlib.import_module(get_module_name(wrapper_name))
wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name))

if isinstance(wrapper_name, str):
wrapper_module = importlib.import_module(get_module_name(wrapper_name))
wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name))
elif isinstance(wrapper_name, type):
# No conversion needed
wrapper_class = wrapper_name
else:
raise ValueError(
f"Unexpected value {wrapper_name} for a {key}, must a str and a class, not {type(wrapper_name)}"
)

wrapper_classes.append(wrapper_class)
wrapper_kwargs.append(kwargs)

def wrap_env(env: gym.Env) -> gym.Env:
    """
    Apply the collected wrappers to an environment, one after another.

    :param env: Environment to wrap.
    :return: The environment after each class in ``wrapper_classes`` has been
        applied in order, each built with its matching ``wrapper_kwargs`` entry.
    """
    wrapped = env
    # Wrappers are applied in list order, so the last one ends up outermost.
    for wrapper_cls, wrapper_options in zip(wrapper_classes, wrapper_kwargs):
        wrapped = wrapper_cls(wrapped, **wrapper_options)
    return wrapped
Expand Down Expand Up @@ -183,8 +189,12 @@ def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]:
else:
kwargs = {}

callback_class = get_class_by_name(callback_name)
callbacks.append(callback_class(**kwargs))
if isinstance(callback_name, BaseCallback):
# No conversion needed
callbacks.append(callback_name)
else:
callback_class = get_class_by_name(callback_name)
callbacks.append(callback_class(**kwargs))

return callbacks

Expand Down
2 changes: 1 addition & 1 deletion rl_zoo3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.5.0a0
2.5.0a1
28 changes: 28 additions & 0 deletions tests/test_callbacks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import shlex
import subprocess

import pytest
import stable_baselines3 as sb3

from rl_zoo3.utils import get_callback_list


def _assert_eq(left, right):
assert left == right, f"{left} != {right}"
Expand All @@ -13,3 +18,26 @@ def test_raw_stat_callback(tmp_path):
)
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)


@pytest.mark.parametrize(
    "callback",
    [
        None,
        "rl_zoo3.callbacks.RawStatisticsCallback",
        [
            {"stable_baselines3.common.callbacks.StopTrainingOnMaxEpisodes": {"max_episodes": 3}},
            "rl_zoo3.callbacks.RawStatisticsCallback",
        ],
        [sb3.common.callbacks.StopTrainingOnMaxEpisodes(3)],
    ],
)
def test_get_callback(callback):
    """Check that ``get_callback_list`` returns one callback per configured entry."""
    built = get_callback_list({"callback": callback})

    # Work out how many callbacks the given spec should produce.
    if callback is None:
        expected_count = 0
    elif isinstance(callback, str):
        # A bare string names a single callback class.
        expected_count = 1
    else:
        expected_count = len(callback)

    assert len(built) == expected_count
6 changes: 5 additions & 1 deletion tests/test_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import gymnasium as gym
import pytest
import stable_baselines3 as sb3
from stable_baselines3 import A2C
from stable_baselines3.common.env_checker import check_env
from stable_baselines3.common.env_util import DummyVecEnv

import rl_zoo3.import_envs # noqa: F401
import rl_zoo3.import_envs
import rl_zoo3.wrappers
from rl_zoo3.utils import get_wrapper_class
from rl_zoo3.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper

Expand All @@ -24,6 +26,7 @@ def test_wrappers():
None,
{"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=2)},
[{"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=3)}, "rl_zoo3.wrappers.TimeFeatureWrapper"],
[{rl_zoo3.wrappers.HistoryWrapper: dict(horizon=3)}, "rl_zoo3.wrappers.TimeFeatureWrapper"],
],
)
def test_get_wrapper(env_wrapper):
Expand All @@ -40,6 +43,7 @@ def test_get_wrapper(env_wrapper):
[
None,
{"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=2)},
{sb3.common.vec_env.VecFrameStack: dict(n_stack=2)},
[{"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=3)}, "stable_baselines3.common.vec_env.VecMonitor"],
],
)
Expand Down

0 comments on commit 506bb7a

Please sign in to comment.