
Commit 9bbabc1

araffin and Quentin18 authored
Add --eval-env-kwargs to train.py (#406)
* Add `--eval-env-kwargs` to `train.py`
* Fix style
* Fix default value for eval_env_kwargs
* Update CHANGELOG.md
* Simplify tests using shlex
* Use shlex in most tests
* Add run test for env kwargs
* Fix eval env kwargs defaults
* Fix test
* Replace last tests

Co-authored-by: Quentin18 <[email protected]>
1 parent a6810f1 commit 9bbabc1

8 files changed, +126 −267 lines changed

CHANGELOG.md

+3 −1

@@ -1,9 +1,10 @@
-## Release 2.2.0a2 (WIP)
+## Release 2.2.0a4 (WIP)

 ### Breaking Changes
 - Removed `gym` dependency, the package is still required for some pretrained agents.

 ### New Features
+- Add `--eval-env-kwargs` to `train.py` (@Quentin18)

 ### Bug fixes

@@ -13,6 +14,7 @@
 - Updated docker image, removed support for X server
 - Replaced deprecated `optuna.suggest_uniform(...)` by `optuna.suggest_float(..., low=..., high=...)`
 - Switched to ruff for sorting imports
+- Updated tests to use `shlex.split()`

 ## Release 2.1.0 (2023-08-17)

rl_zoo3/exp_manager.py

+7 −2

@@ -73,6 +73,7 @@ def __init__(
         save_freq: int = -1,
         hyperparams: Optional[Dict[str, Any]] = None,
         env_kwargs: Optional[Dict[str, Any]] = None,
+        eval_env_kwargs: Optional[Dict[str, Any]] = None,
         trained_agent: str = "",
         optimize_hyperparameters: bool = False,
         storage: Optional[str] = None,
@@ -111,7 +112,7 @@ def __init__(
         default_path = Path(__file__).parent.parent

         self.config = config or str(default_path / f"hyperparams/{self.algo}.yml")
-        self.env_kwargs: Dict[str, Any] = {} if env_kwargs is None else env_kwargs
+        self.env_kwargs: Dict[str, Any] = env_kwargs or {}
         self.n_timesteps = n_timesteps
         self.normalize = False
         self.normalize_kwargs: Dict[str, Any] = {}
@@ -129,6 +130,8 @@ def __init__(
         # Callbacks
         self.specified_callbacks: List = []
         self.callbacks: List[BaseCallback] = []
+        # Use env-kwargs if eval_env_kwargs was not specified
+        self.eval_env_kwargs: Dict[str, Any] = eval_env_kwargs or self.env_kwargs
         self.save_freq = save_freq
         self.eval_freq = eval_freq
         self.n_eval_episodes = n_eval_episodes
@@ -604,13 +607,15 @@ def create_envs(self, n_envs: int, eval_env: bool = False, no_log: bool = False)
         def make_env(**kwargs) -> gym.Env:
             return spec.make(**kwargs)

+        env_kwargs = self.eval_env_kwargs if eval_env else self.env_kwargs
+
         # On most env, SubprocVecEnv does not help and is quite memory hungry,
         # therefore, we use DummyVecEnv by default
         env = make_vec_env(
             make_env,
             n_envs=n_envs,
             seed=self.seed,
-            env_kwargs=self.env_kwargs,
+            env_kwargs=env_kwargs,
             monitor_dir=log_dir,
             wrapper_class=self.env_wrapper,
             vec_env_cls=self.vec_env_class,
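Because the fallback uses Python's `or`, omitting `--eval-env-kwargs` (or passing an empty dict) makes the evaluation env reuse the training kwargs. A minimal standalone sketch of that resolution logic, with a hypothetical helper name (not the actual `ExperimentManager` code):

from typing import Any, Dict, Optional

def resolve_env_kwargs(
    env_kwargs: Optional[Dict[str, Any]],
    eval_env_kwargs: Optional[Dict[str, Any]],
    eval_env: bool,
) -> Dict[str, Any]:
    # Mirrors the commit: env_kwargs defaults to {}, eval_env_kwargs falls back to env_kwargs.
    train_kwargs: Dict[str, Any] = env_kwargs or {}
    eval_kwargs: Dict[str, Any] = eval_env_kwargs or train_kwargs
    return eval_kwargs if eval_env else train_kwargs

# When --eval-env-kwargs is omitted, the eval env reuses the training kwargs.
assert resolve_env_kwargs({"g": 9.81}, None, eval_env=True) == {"g": 9.81}
# When it is given, the eval env gets its own kwargs.
assert resolve_env_kwargs({"g": 9.81}, {"g": 10.0}, eval_env=True) == {"g": 10.0}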

rl_zoo3/train.py

+8 −0

@@ -114,6 +114,13 @@ def train() -> None:
     parser.add_argument(
         "--env-kwargs", type=str, nargs="+", action=StoreDict, help="Optional keyword argument to pass to the env constructor"
     )
+    parser.add_argument(
+        "--eval-env-kwargs",
+        type=str,
+        nargs="+",
+        action=StoreDict,
+        help="Optional keyword argument to pass to the env constructor for evaluation",
+    )
     parser.add_argument(
         "-params",
         "--hyperparams",
@@ -223,6 +230,7 @@ def train() -> None:
         args.save_freq,
         args.hyperparams,
         args.env_kwargs,
+        args.eval_env_kwargs,
         args.trained_agent,
         args.optimize_hyperparameters,
         args.storage,
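For reference, a hedged usage sketch of the new flag, written in the style of the updated tests; `Pendulum-v1` and the `g` (gravity) kwarg are illustrative values, not part of this commit:

import shlex
import subprocess

# Hypothetical example: train with one gravity value and evaluate with another.
# --eval-env-kwargs is the flag added by this commit; values are parsed by StoreDict as key:value.
cmd = (
    "python train.py --algo ppo --env Pendulum-v1 -n 1000 "
    "--env-kwargs g:9.81 --eval-env-kwargs g:10.0"
)
return_code = subprocess.call(shlex.split(cmd))
assert return_code == 0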

rl_zoo3/version.txt

+1 −1

@@ -1 +1 @@
-2.2.0a2
+2.2.0a4

tests/test_callbacks.py

+6 −14

@@ -1,3 +1,4 @@
+import shlex
 import subprocess


@@ -6,18 +7,9 @@ def _assert_eq(left, right):


 def test_raw_stat_callback(tmp_path):
-    args = [
-        "-n",
-        str(200),
-        "--algo",
-        "ppo",
-        "--env",
-        "CartPole-v1",
-        "-params",
-        "callback:'rl_zoo3.callbacks.RawStatisticsCallback'",
-        "--tensorboard-log",
-        f"{tmp_path}",
-    ]
-
-    return_code = subprocess.call(["python", "train.py", *args])
+    cmd = (
+        f"python train.py -n 200 --algo ppo --env CartPole-v1 --log-folder {tmp_path} "
+        f"--tensorboard-log {tmp_path} -params callback:\"'rl_zoo3.callbacks.RawStatisticsCallback'\""
+    )
+    return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)
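The switch to `shlex.split()` matters here because of the nested quoting around the callback spec: it tokenizes the string like a POSIX shell, so the quoted value reaches argparse as a single argument. A small illustration, using the command fragment from the test above:

import shlex

cmd = "python train.py -params callback:\"'rl_zoo3.callbacks.RawStatisticsCallback'\""
print(shlex.split(cmd))
# ['python', 'train.py', '-params', "callback:'rl_zoo3.callbacks.RawStatisticsCallback'"]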

tests/test_enjoy.py

+30 −62

@@ -1,6 +1,6 @@
 import os
+import shlex
 import subprocess
-import sys

 import pytest

@@ -23,7 +23,6 @@ def _assert_eq(left, right):
 @pytest.mark.slow
 def test_trained_agents(trained_model):
     algo, env_id = trained_models[trained_model]
-    args = ["-n", str(N_STEPS), "-f", FOLDER, "--algo", algo, "--env", env_id, "--no-render"]

     # Since SB3 >= 1.1.0, HER is no more an algorithm but a replay buffer class
     if algo == "her":
@@ -44,69 +43,55 @@ def test_trained_agents(trained_model):

     # FIXME: switch to MiniGrid package
     if "-MiniGrid-" in trained_model:
-        # Skip for python 3.7, see https://github.com/DLR-RM/rl-baselines3-zoo/pull/372#issuecomment-1490562332
-        if sys.version_info[:2] == (3, 7):
-            pytest.skip("MiniGrid env does not work with Python 3.7")
         # FIXME: switch to Gymnsium
         return

-    return_code = subprocess.call(["python", "enjoy.py", *args])
+    cmd = f"python enjoy.py --algo {algo} --env {env_id} -n {N_STEPS} -f {FOLDER} --no-render"
+    return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)


 def test_benchmark(tmp_path):
-    args = ["-n", str(N_STEPS), "--benchmark-dir", tmp_path, "--test-mode", "--no-hub"]
-
-    return_code = subprocess.call(["python", "-m", "rl_zoo3.benchmark", *args])
+    cmd = f"python -m rl_zoo3.benchmark -n {N_STEPS} --benchmark-dir {tmp_path} --test-mode --no-hub"
+    return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)


 def test_load(tmp_path):
     algo, env_id = "a2c", "CartPole-v1"
-    args = [
-        "-n",
-        str(1000),
-        "--algo",
-        algo,
-        "--env",
-        env_id,
-        "-params",
-        "n_envs:1",
-        "--log-folder",
-        tmp_path,
-        "--eval-freq",
-        str(500),
-        "--save-freq",
-        str(500),
-        "-P",  # Enable progress bar
-    ]
     # Train and save checkpoints and best model
-    return_code = subprocess.call(["python", "train.py", *args])
+    cmd = (
+        f"python train.py --algo {algo} --env {env_id} -n 1000 -f {tmp_path} "
+        # Enable progress bar
+        f"-params n_envs:1 --eval-freq 500 --save-freq 500 -P"
+    )
+    return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)

     # Load best model
-    args = ["-n", str(N_STEPS), "-f", tmp_path, "--algo", algo, "--env", env_id, "--no-render"]
-    # Test with progress bar
-    return_code = subprocess.call(["python", "enjoy.py", *args, "--load-best", "-P"])
+    base_cmd = f"python enjoy.py --algo {algo} --env {env_id} -n {N_STEPS} -f {tmp_path} --no-render "
+    # Enable progress bar
+    return_code = subprocess.call(shlex.split(base_cmd + "--load-best -P"))
+
     _assert_eq(return_code, 0)

     # Load checkpoint
-    return_code = subprocess.call(["python", "enjoy.py", *args, "--load-checkpoint", str(500)])
+    return_code = subprocess.call(shlex.split(base_cmd + "--load-checkpoint 500"))
     _assert_eq(return_code, 0)

     # Load last checkpoint
-    return_code = subprocess.call(["python", "enjoy.py", *args, "--load-last-checkpoint"])
+    return_code = subprocess.call(shlex.split(base_cmd + "--load-last-checkpoint"))
     _assert_eq(return_code, 0)


 def test_record_video(tmp_path):
-    args = ["-n", "100", "--algo", "sac", "--env", "Pendulum-v1", "-o", str(tmp_path)]
-
     # Skip if no X-Server
     if not os.environ.get("DISPLAY"):
         pytest.skip("No X-Server")

-    return_code = subprocess.call(["python", "-m", "rl_zoo3.record_video", *args])
+    cmd = f"python -m rl_zoo3.record_video -n 100 --algo sac --env Pendulum-v1 -o {tmp_path}"
+    return_code = subprocess.call(shlex.split(cmd))
+
     _assert_eq(return_code, 0)
     video_path = str(tmp_path / "final-model-sac-Pendulum-v1-step-0-to-step-100.mp4")
     # File is not empty
@@ -115,41 +100,24 @@ def test_record_video(tmp_path):

 def test_record_training(tmp_path):
     videos_tmp_path = tmp_path / "videos"
-    args_training = [
-        "--algo",
-        "ppo",
-        "--env",
-        "CartPole-v1",
-        "--log-folder",
-        str(tmp_path),
-        "--save-freq",
-        "4000",
-        "-n",
-        "10000",
-    ]
-    args_recording = [
-        "--algo",
-        "ppo",
-        "--env",
-        "CartPole-v1",
-        "--gif",
-        "-n",
-        "100",
-        "-f",
-        str(tmp_path),
-        "-o",
-        str(videos_tmp_path),
-    ]
+    algo, env_id = "ppo", "CartPole-v1"

     # Skip if no X-Server
     if not os.environ.get("DISPLAY"):
         pytest.skip("No X-Server")

-    return_code = subprocess.call(["python", "train.py", *args_training])
+    cmd = f"python train.py -n 10000 --algo {algo} --env {env_id} --log-folder {tmp_path} --save-freq 4000 "
+    return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)

-    return_code = subprocess.call(["python", "-m", "rl_zoo3.record_training", *args_recording])
+    cmd = (
+        f"python -m rl_zoo3.record_training -n 100 --algo {algo} --env {env_id} "
+        f"--f {tmp_path} "
+        f"--gif -o {videos_tmp_path}"
+    )
+    return_code = subprocess.call(shlex.split(cmd))
     _assert_eq(return_code, 0)
+
     mp4_path = str(videos_tmp_path / "training.mp4")
     gif_path = str(videos_tmp_path / "training.gif")
     # File is not empty
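Note that `base_cmd` in `test_load` ends with a trailing space, so the suffixes appended before splitting stay separate tokens. A quick sanity check of that pattern (paths and values here are illustrative only):

import shlex

base_cmd = "python enjoy.py --algo a2c --env CartPole-v1 -f /tmp/logs --no-render "  # trailing space
tokens = shlex.split(base_cmd + "--load-best -P")
# The appended options arrive as their own arguments.
assert tokens[-2:] == ["--load-best", "-P"]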
