diff --git a/CHANGELOG.md b/CHANGELOG.md index acdf4ed5e..450ad3faa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## Release 2.5.0a0 (WIP) +## Release 2.5.0a1 (WIP) ### Breaking Changes - Upgraded to Pytorch >= 2.3.0 @@ -6,6 +6,7 @@ ### New Features - Added support for Numpy v2 +- Added support for specifying callbacks and env wrapper as python object in python config files (instead of string) ### Bug fixes diff --git a/rl_zoo3/enjoy.py b/rl_zoo3/enjoy.py index 86225650a..082340aae 100644 --- a/rl_zoo3/enjoy.py +++ b/rl_zoo3/enjoy.py @@ -204,7 +204,7 @@ def enjoy() -> None: # noqa: C901 obs = env.reset() # Deterministic by default except for atari games - stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic + stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic) deterministic = not stochastic episode_reward = 0.0 diff --git a/rl_zoo3/exp_manager.py b/rl_zoo3/exp_manager.py index 321a0378a..b43c16033 100644 --- a/rl_zoo3/exp_manager.py +++ b/rl_zoo3/exp_manager.py @@ -602,8 +602,7 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False) if ( "Neck" in self.env_name.gym_id or self.is_robotics_env(self.env_name.gym_id) - or "parking-v0" in self.env_name.gym_id - and len(self.monitor_kwargs) == 0 # do not overwrite custom kwargs + or ("parking-v0" in self.env_name.gym_id and len(self.monitor_kwargs) == 0) # do not overwrite custom kwargs ): self.monitor_kwargs = dict(info_keywords=("is_success",)) diff --git a/rl_zoo3/push_to_hub.py b/rl_zoo3/push_to_hub.py index 499bcc366..5f10cc56e 100644 --- a/rl_zoo3/push_to_hub.py +++ b/rl_zoo3/push_to_hub.py @@ -398,7 +398,7 @@ def package_to_hub( model = ALGOS[algo].load(model_path, env=eval_env, custom_objects=custom_objects, device=args.device, **kwargs) # Deterministic by default except for atari games - stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic + stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic) deterministic = not stochastic # Default model name, the model will be saved under "{algo}-{env_name}.zip" diff --git a/rl_zoo3/record_video.py b/rl_zoo3/record_video.py index a2b2071d4..ab9ecae65 100644 --- a/rl_zoo3/record_video.py +++ b/rl_zoo3/record_video.py @@ -131,7 +131,7 @@ model = ALGOS[algo].load(model_path, env=env, custom_objects=custom_objects, **kwargs) # Deterministic by default except for atari games - stochastic = args.stochastic or (is_atari or is_minigrid) and not args.deterministic + stochastic = args.stochastic or ((is_atari or is_minigrid) and not args.deterministic) deterministic = not stochastic if video_folder is None: diff --git a/rl_zoo3/utils.py b/rl_zoo3/utils.py index 30d557945..961e7fffa 100644 --- a/rl_zoo3/utils.py +++ b/rl_zoo3/utils.py @@ -99,16 +99,22 @@ def get_class_name(wrapper_name): kwargs = wrapper_dict[wrapper_name] else: kwargs = {} - wrapper_module = importlib.import_module(get_module_name(wrapper_name)) - wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name)) + + if isinstance(wrapper_name, str): + wrapper_module = importlib.import_module(get_module_name(wrapper_name)) + wrapper_class = getattr(wrapper_module, get_class_name(wrapper_name)) + elif isinstance(wrapper_name, type): + # No conversion needed + wrapper_class = wrapper_name + else: + raise ValueError( + f"Unexpected value {wrapper_name} for a {key}, must a str and a class, not {type(wrapper_name)}" + ) + wrapper_classes.append(wrapper_class) wrapper_kwargs.append(kwargs) def wrap_env(env: gym.Env) -> gym.Env: - """ - :param env: - :return: - """ for wrapper_class, kwargs in zip(wrapper_classes, wrapper_kwargs): env = wrapper_class(env, **kwargs) return env @@ -183,8 +189,12 @@ def get_callback_list(hyperparams: dict[str, Any]) -> list[BaseCallback]: else: kwargs = {} - callback_class = get_class_by_name(callback_name) - callbacks.append(callback_class(**kwargs)) + if isinstance(callback_name, BaseCallback): + # No conversion needed + callbacks.append(callback_name) + else: + callback_class = get_class_by_name(callback_name) + callbacks.append(callback_class(**kwargs)) return callbacks diff --git a/rl_zoo3/version.txt b/rl_zoo3/version.txt index b8feefb94..240183827 100644 --- a/rl_zoo3/version.txt +++ b/rl_zoo3/version.txt @@ -1 +1 @@ -2.5.0a0 +2.5.0a1 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 8f18fbaec..d09341026 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -1,6 +1,11 @@ import shlex import subprocess +import pytest +import stable_baselines3 as sb3 + +from rl_zoo3.utils import get_callback_list + def _assert_eq(left, right): assert left == right, f"{left} != {right}" @@ -13,3 +18,26 @@ def test_raw_stat_callback(tmp_path): ) return_code = subprocess.call(shlex.split(cmd)) _assert_eq(return_code, 0) + + +@pytest.mark.parametrize( + "callback", + [ + None, + "rl_zoo3.callbacks.RawStatisticsCallback", + [ + {"stable_baselines3.common.callbacks.StopTrainingOnMaxEpisodes": dict(max_episodes=3)}, + "rl_zoo3.callbacks.RawStatisticsCallback", + ], + [sb3.common.callbacks.StopTrainingOnMaxEpisodes(3)], + ], +) +def test_get_callback(callback): + hyperparams = {"callback": callback} + callback_list = get_callback_list(hyperparams) + if callback is None: + assert len(callback_list) == 0 + elif isinstance(callback, str): + assert len(callback_list) == 1 + else: + assert len(callback_list) == len(callback) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 6b06f765f..943acf9e8 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -1,10 +1,12 @@ import gymnasium as gym import pytest +import stable_baselines3 as sb3 from stable_baselines3 import A2C from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import DummyVecEnv -import rl_zoo3.import_envs # noqa: F401 +import rl_zoo3.import_envs +import rl_zoo3.wrappers from rl_zoo3.utils import get_wrapper_class from rl_zoo3.wrappers import ActionNoiseWrapper, DelayedRewardWrapper, HistoryWrapper, TimeFeatureWrapper @@ -24,6 +26,7 @@ def test_wrappers(): None, {"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=2)}, [{"rl_zoo3.wrappers.HistoryWrapper": dict(horizon=3)}, "rl_zoo3.wrappers.TimeFeatureWrapper"], + [{rl_zoo3.wrappers.HistoryWrapper: dict(horizon=3)}, "rl_zoo3.wrappers.TimeFeatureWrapper"], ], ) def test_get_wrapper(env_wrapper): @@ -40,6 +43,7 @@ def test_get_wrapper(env_wrapper): [ None, {"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=2)}, + {sb3.common.vec_env.VecFrameStack: dict(n_stack=2)}, [{"stable_baselines3.common.vec_env.VecFrameStack": dict(n_stack=3)}, "stable_baselines3.common.vec_env.VecMonitor"], ], )