1
1
import os
2
+ import shlex
2
3
import subprocess
3
- import sys
4
4
5
5
import pytest
6
6
@@ -23,7 +23,6 @@ def _assert_eq(left, right):
23
23
@pytest .mark .slow
24
24
def test_trained_agents (trained_model ):
25
25
algo , env_id = trained_models [trained_model ]
26
- args = ["-n" , str (N_STEPS ), "-f" , FOLDER , "--algo" , algo , "--env" , env_id , "--no-render" ]
27
26
28
27
# Since SB3 >= 1.1.0, HER is no more an algorithm but a replay buffer class
29
28
if algo == "her" :
@@ -44,69 +43,55 @@ def test_trained_agents(trained_model):
44
43
45
44
# FIXME: switch to MiniGrid package
46
45
if "-MiniGrid-" in trained_model :
47
- # Skip for python 3.7, see https://github.com/DLR-RM/rl-baselines3-zoo/pull/372#issuecomment-1490562332
48
- if sys .version_info [:2 ] == (3 , 7 ):
49
- pytest .skip ("MiniGrid env does not work with Python 3.7" )
50
46
# FIXME: switch to Gymnsium
51
47
return
52
48
53
- return_code = subprocess .call (["python" , "enjoy.py" , * args ])
49
+ cmd = f"python enjoy.py --algo { algo } --env { env_id } -n { N_STEPS } -f { FOLDER } --no-render"
50
+ return_code = subprocess .call (shlex .split (cmd ))
54
51
_assert_eq (return_code , 0 )
55
52
56
53
57
54
def test_benchmark (tmp_path ):
58
- args = ["-n" , str (N_STEPS ), "--benchmark-dir" , tmp_path , "--test-mode" , "--no-hub" ]
59
-
60
- return_code = subprocess .call (["python" , "-m" , "rl_zoo3.benchmark" , * args ])
55
+ cmd = f"python -m rl_zoo3.benchmark -n { N_STEPS } --benchmark-dir { tmp_path } --test-mode --no-hub"
56
+ return_code = subprocess .call (shlex .split (cmd ))
61
57
_assert_eq (return_code , 0 )
62
58
63
59
64
60
def test_load (tmp_path ):
65
61
algo , env_id = "a2c" , "CartPole-v1"
66
- args = [
67
- "-n" ,
68
- str (1000 ),
69
- "--algo" ,
70
- algo ,
71
- "--env" ,
72
- env_id ,
73
- "-params" ,
74
- "n_envs:1" ,
75
- "--log-folder" ,
76
- tmp_path ,
77
- "--eval-freq" ,
78
- str (500 ),
79
- "--save-freq" ,
80
- str (500 ),
81
- "-P" , # Enable progress bar
82
- ]
83
62
# Train and save checkpoints and best model
84
- return_code = subprocess .call (["python" , "train.py" , * args ])
63
+ cmd = (
64
+ f"python train.py --algo { algo } --env { env_id } -n 1000 -f { tmp_path } "
65
+ # Enable progress bar
66
+ f"-params n_envs:1 --eval-freq 500 --save-freq 500 -P"
67
+ )
68
+ return_code = subprocess .call (shlex .split (cmd ))
85
69
_assert_eq (return_code , 0 )
86
70
87
71
# Load best model
88
- args = ["-n" , str (N_STEPS ), "-f" , tmp_path , "--algo" , algo , "--env" , env_id , "--no-render" ]
89
- # Test with progress bar
90
- return_code = subprocess .call (["python" , "enjoy.py" , * args , "--load-best" , "-P" ])
72
+ base_cmd = f"python enjoy.py --algo { algo } --env { env_id } -n { N_STEPS } -f { tmp_path } --no-render "
73
+ # Enable progress bar
74
+ return_code = subprocess .call (shlex .split (base_cmd + "--load-best -P" ))
75
+
91
76
_assert_eq (return_code , 0 )
92
77
93
78
# Load checkpoint
94
- return_code = subprocess .call ([ "python" , "enjoy.py" , * args , "--load-checkpoint" , str ( 500 )] )
79
+ return_code = subprocess .call (shlex . split ( base_cmd + "--load-checkpoint 500" ) )
95
80
_assert_eq (return_code , 0 )
96
81
97
82
# Load last checkpoint
98
- return_code = subprocess .call ([ "python" , "enjoy.py" , * args , "--load-last-checkpoint" ] )
83
+ return_code = subprocess .call (shlex . split ( base_cmd + "--load-last-checkpoint" ) )
99
84
_assert_eq (return_code , 0 )
100
85
101
86
102
87
def test_record_video (tmp_path ):
103
- args = ["-n" , "100" , "--algo" , "sac" , "--env" , "Pendulum-v1" , "-o" , str (tmp_path )]
104
-
105
88
# Skip if no X-Server
106
89
if not os .environ .get ("DISPLAY" ):
107
90
pytest .skip ("No X-Server" )
108
91
109
- return_code = subprocess .call (["python" , "-m" , "rl_zoo3.record_video" , * args ])
92
+ cmd = f"python -m rl_zoo3.record_video -n 100 --algo sac --env Pendulum-v1 -o { tmp_path } "
93
+ return_code = subprocess .call (shlex .split (cmd ))
94
+
110
95
_assert_eq (return_code , 0 )
111
96
video_path = str (tmp_path / "final-model-sac-Pendulum-v1-step-0-to-step-100.mp4" )
112
97
# File is not empty
@@ -115,41 +100,24 @@ def test_record_video(tmp_path):
115
100
116
101
def test_record_training (tmp_path ):
117
102
videos_tmp_path = tmp_path / "videos"
118
- args_training = [
119
- "--algo" ,
120
- "ppo" ,
121
- "--env" ,
122
- "CartPole-v1" ,
123
- "--log-folder" ,
124
- str (tmp_path ),
125
- "--save-freq" ,
126
- "4000" ,
127
- "-n" ,
128
- "10000" ,
129
- ]
130
- args_recording = [
131
- "--algo" ,
132
- "ppo" ,
133
- "--env" ,
134
- "CartPole-v1" ,
135
- "--gif" ,
136
- "-n" ,
137
- "100" ,
138
- "-f" ,
139
- str (tmp_path ),
140
- "-o" ,
141
- str (videos_tmp_path ),
142
- ]
103
+ algo , env_id = "ppo" , "CartPole-v1"
143
104
144
105
# Skip if no X-Server
145
106
if not os .environ .get ("DISPLAY" ):
146
107
pytest .skip ("No X-Server" )
147
108
148
- return_code = subprocess .call (["python" , "train.py" , * args_training ])
109
+ cmd = f"python train.py -n 10000 --algo { algo } --env { env_id } --log-folder { tmp_path } --save-freq 4000 "
110
+ return_code = subprocess .call (shlex .split (cmd ))
149
111
_assert_eq (return_code , 0 )
150
112
151
- return_code = subprocess .call (["python" , "-m" , "rl_zoo3.record_training" , * args_recording ])
113
+ cmd = (
114
+ f"python -m rl_zoo3.record_training -n 100 --algo { algo } --env { env_id } "
115
+ f"--f { tmp_path } "
116
+ f"--gif -o { videos_tmp_path } "
117
+ )
118
+ return_code = subprocess .call (shlex .split (cmd ))
152
119
_assert_eq (return_code , 0 )
120
+
153
121
mp4_path = str (videos_tmp_path / "training.mp4" )
154
122
gif_path = str (videos_tmp_path / "training.gif" )
155
123
# File is not empty
0 commit comments