From 76992081ad2eb0b27b19fa6d6c64915d885731ac Mon Sep 17 00:00:00 2001
From: Antonin Raffin
Date: Tue, 5 Nov 2024 17:59:54 +0100
Subject: [PATCH] Fix vecnormalize stats

---
 rl-trained-agents   |  2 +-
 rl_zoo3/enjoy.py    | 10 +++++++++-
 tests/test_train.py | 12 ++++++++++--
 3 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/rl-trained-agents b/rl-trained-agents
index eb1bd43ec..cd35bde61 160000
--- a/rl-trained-agents
+++ b/rl-trained-agents
@@ -1 +1 @@
-Subproject commit eb1bd43ece6554de857d71ddd8ad8f5878958ff8
+Subproject commit cd35bde610f4045bf2e0731c8f4c88d22df8fc85
diff --git a/rl_zoo3/enjoy.py b/rl_zoo3/enjoy.py
index cd7f08056..86225650a 100644
--- a/rl_zoo3/enjoy.py
+++ b/rl_zoo3/enjoy.py
@@ -184,7 +184,9 @@ def enjoy() -> None:  # noqa: C901
             "learning_rate": 0.0,
             "lr_schedule": lambda _: 0.0,
             "clip_range": lambda _: 0.0,
-            # "observation_space": env.observation_space,  # load models with different obs bounds
+            # load models with different obs bounds
+            # Note: doesn't work with channel last envs
+            # "observation_space": env.observation_space,
         }
 
     if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""):
@@ -193,6 +195,12 @@ def enjoy() -> None:  # noqa: C901
     model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs)
     # Uncomment to save patched file (for instance gym -> gymnasium)
     # model.save(model_path)
+    # Patch VecNormalize (gym -> gymnasium)
+    # from pathlib import Path
+    # env.observation_space = model.observation_space
+    # env.action_space = model.action_space
+    # env.save(Path(model_path).parent / env_name / "vecnormalize.pkl")
+
     obs = env.reset()
 
     # Deterministic by default except for atari games
diff --git a/tests/test_train.py b/tests/test_train.py
index d0780acc1..0894cf669 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -1,6 +1,7 @@
 import os
 import shlex
 import subprocess
+from importlib.metadata import version
 
 import pytest
 
@@ -36,10 +37,17 @@ def test_train(tmp_path, experiment):
 
 
 def test_continue_training(tmp_path):
-    algo, env_id = "a2c", "CartPole-v1"
+    algo = "a2c"
+    if version("gymnasium") > "0.29.1":
+        # See https://github.com/DLR-RM/stable-baselines3/pull/1837#issuecomment-2457322341
+        # obs bounds have changed...
+        env_id = "CartPole-v1"
+    else:
+        env_id = "Pendulum-v1"
+
     cmd = (
         f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} "
-        "-i rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip"
+        f"-i rl-trained-agents/a2c/{env_id}_1/{env_id}.zip"
    )
     return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)
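
The block added to enjoy.py stays commented out on purpose: it is meant to be enabled by hand once, to re-save vecnormalize.pkl files whose pickled observation/action spaces still come from the legacy gym package (the re-saved statistics are what the rl-trained-agents submodule bump picks up). Below is a rough standalone sketch of that one-off flow, not part of the diff: the Pendulum-v1 env id, the A2C class, and the file paths are illustrative assumptions; only the space copy and the env.save(...) call mirror the commented code.

    # Minimal sketch (assumed setup): re-save VecNormalize stats so the pickled
    # spaces come from gymnasium instead of legacy gym. Paths and env/algo are
    # placeholders, not taken from the patch.
    from pathlib import Path

    import gymnasium as gym
    from stable_baselines3 import A2C
    from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize

    model_path = Path("rl-trained-agents/a2c/Pendulum-v1_1/Pendulum-v1.zip")  # placeholder
    stats_path = model_path.parent / "Pendulum-v1" / "vecnormalize.pkl"       # placeholder

    # Rebuild a vectorized env and restore the old normalization statistics.
    venv = DummyVecEnv([lambda: gym.make("Pendulum-v1")])
    venv = VecNormalize.load(str(stats_path), venv)

    # Load the agent, copy its (gymnasium) spaces onto the VecNormalize wrapper,
    # then overwrite the pickle, as the commented-out block in enjoy.py does.
    model = A2C.load(model_path)
    venv.observation_space = model.observation_space
    venv.action_space = model.action_space
    venv.save(str(stats_path))

Run once per affected environment, this rewrites the pickle in place; afterwards enjoy.py should load the statistics without a space mismatch.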