Skip to content

Commit

Permalink
Fix vecnormalize stats
Browse files (browse the repository at this point in the history)
  • Loading branch information
araffin committed Nov 5, 2024
1 parent bc3514e commit 7699208
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 4 deletions.
10 changes: 9 additions & 1 deletion rl_zoo3/enjoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,9 @@ def enjoy() -> None: # noqa: C901
"learning_rate": 0.0,
"lr_schedule": lambda _: 0.0,
"clip_range": lambda _: 0.0,
# "observation_space": env.observation_space, # load models with different obs bounds
# load models with different obs bounds
# Note: doesn't work with channel last envs
# "observation_space": env.observation_space,
}

if "HerReplayBuffer" in hyperparams.get("replay_buffer_class", ""):
Expand All @@ -193,6 +195,12 @@ def enjoy() -> None: # noqa: C901
model = ALGOS[algo].load(model_path, custom_objects=custom_objects, device=args.device, **kwargs)
# Uncomment to save patched file (for instance gym -> gymnasium)
# model.save(model_path)
# Patch VecNormalize (gym -> gymnasium)
# from pathlib import Path
# env.observation_space = model.observation_space
# env.action_space = model.action_space
# env.save(Path(model_path).parent / env_name / "vecnormalize.pkl")

obs = env.reset()

# Deterministic by default except for atari games
Expand Down
12 changes: 10 additions & 2 deletions tests/test_train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import shlex
import subprocess
from importlib.metadata import version

import pytest

Expand Down Expand Up @@ -36,10 +37,17 @@ def test_train(tmp_path, experiment):


def test_continue_training(tmp_path):
algo, env_id = "a2c", "CartPole-v1"
algo = "a2c"
if version("gymnasium") > "0.29.1":
# See https://github.com/DLR-RM/stable-baselines3/pull/1837#issuecomment-2457322341
# obs bounds have changed...
env_id = "CartPole-v1"
else:
env_id = "Pendulum-v1"

cmd = (
f"python train.py -n {N_STEPS} --algo {algo} --env {env_id} --log-folder {tmp_path} "
"-i rl-trained-agents/a2c/CartPole-v1_1/CartPole-v1.zip"
f"-i rl-trained-agents/a2c/{env_id}_1/{env_id}.zip"
)
return_code = subprocess.call(shlex.split(cmd))
_assert_eq(return_code, 0)
Expand Down

0 comments on commit 7699208

Please sign in to comment.