Commit c52ae8a

include SSD and thres SSD in eval_policy

Author: rfal
1 parent dafb657

File tree

10 files changed: +99 -102 lines changed

README.md

Lines changed: 1 addition & 1 deletion
@@ -54,4 +54,4 @@ $ python run_stable_baselines3.py -C [experiment config file (required)] -P [num
 Example configuration files are provided in the **config** directory, and see [parameters.md](parameters.md) for detailed explanations of common parameters.
 
 ## Third Party Libraries
-This project uses implementations of A2C, PPO, DQN and QRDQN agents from [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) and [stable-baselines3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib), and makes some modifications to apply to the proposed environment. There are some agent specific parameters in the provided configuration files, please refer to [on_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/on_policy_algorithm.py) ((A2C and PPO)) and [off_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/off_policy_algorithm.py) (DQN and QRDQN) for further information.
+This project uses implementations of A2C, PPO, DQN and QR-DQN agents from [stable-baselines3](https://github.com/DLR-RM/stable-baselines3) and [stable-baselines3-contrib](https://github.com/Stable-Baselines-Team/stable-baselines3-contrib), and makes some modifications to apply to the proposed environment. There are some agent-specific parameters in the provided configuration files; please refer to [on_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/on_policy_algorithm.py) (A2C and PPO) and [off_policy_algorithm.py](https://github.com/RobustFieldAutonomyLab/Stochastic_Road_Network/blob/main/thirdparty/stable_baselines3/common/off_policy_algorithm.py) (DQN and QR-DQN) for further information.
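The evaluation options added by this commit are read from the "agent" section of the experiment config (see params_dashboard and run_trial below). A minimal sketch of what that section might look like for a QR-DQN run; the key names (including eval_policy and ssd_thres) come from the scripts in this commit, but the block layout and all values here are illustrative assumptions, not copies of a shipped config file:

    # Hypothetical "agent" block of an experiment config, shown as a Python dict.
    # Only "eval_policy" and "ssd_thres" are new in this commit; the values are assumed.
    agent_config = {
        "name": "QRDQN",
        "discount": 0.99,                      # assumed value
        "alpha": 1e-4,                         # assumed learning rate
        "buffer_size": 50000,                  # assumed value
        "n_quantiles": 50,                     # assumed value
        "eval_policy": "Thresholded_SSD",      # "Greedy", "SSD", or "Thresholded_SSD"
        "ssd_thres": 1e-3,                     # only read when eval_policy == "Thresholded_SSD"
    }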

config/config_A2C_Town01_cnn.json

Lines changed: 0 additions & 22 deletions
This file was deleted.

config/config_DQN_Town01_cnn.json

Lines changed: 0 additions & 25 deletions
This file was deleted.

config/config_PPO_Town01_cnn.json

Lines changed: 0 additions & 24 deletions
This file was deleted.

run_stable_baselines3.py

Lines changed: 17 additions & 26 deletions
@@ -70,28 +70,21 @@ def params_dashboard(params):
     print("seed: ",params["base"]["seed"])
     print("num_timesteps: ",params["base"]["num_timesteps"])
     print("agent: ",params["agent"]["name"])
-    print("policy: ",params["policy"])
+    print("network: ",params["policy"])
     print("discount: ",params["agent"]["discount"])
     print("learning rate: ",params["agent"]["alpha"])
     print("map: ",params["environment"]["map_name"])
     print("start_state: ",params["environment"]["start_state"])
     print("goal_states: ",params["environment"]["goal_states"])
-    print("crosswalk_states: ",params["environment"]["crosswalk_states"],"\n")
+    print("crosswalk_states: ",params["environment"]["crosswalk_states"])
+    if params["agent"]["name"] == "QRDQN":
+        print("eval policy: ",params["agent"]["eval_policy"])
+        if params["agent"]["eval_policy"] == "Thresholded_SSD":
+            print("ssd thres: ",params["agent"]["ssd_thres"])
+    print("\n")
 
 def run_trial(params,device):
 
-    #print("\n====== Trial Setup ======\n")
-    #print("seed: ",params["base"]["seed"])
-    #print("num_timesteps: ",params["base"]["num_timesteps"])
-    #print("agent: ",params["agent"]["name"])
-    #print("policy: ",params["policy"])
-    #print("discount: ",params["agent"]["discount"])
-    #print("learning rate: ",params["agent"]["alpha"])
-    #print("map: ",params["environment"]["map_name"])
-    #print("start_state: ",params["environment"]["start_state"])
-    #print("goal_states: ",params["environment"]["goal_states"])
-    #print("crosswalk_states: ",params["environment"]["crosswalk_states"],"\n")
-
     lr = params["agent"]["alpha"]
     sd = params["base"]["seed"]
     cw = params["environment"]["crosswalk_states"]

@@ -129,7 +122,7 @@ def run_trial(params,device):
     evaluate_env.reset()
 
     if params["agent"]["name"] == "QRDQN":
-        save_dir = os.path.join(params["save_dir"],params["agent"]["name"],params["environment"]["map_name"],params["policy"],stoc,"buffer_"+str(params["agent"]["buffer_size"]),"n_quantile_"+str(params["agent"]["n_quantiles"]),"lr_"+str(lr),"seed_"+str(sd))
+        save_dir = os.path.join(params["save_dir"],params["agent"]["name"],params["environment"]["map_name"],params["policy"],params["agent"]["eval_policy"],stoc,"buffer_"+str(params["agent"]["buffer_size"]),"n_quantile_"+str(params["agent"]["n_quantiles"]),"lr_"+str(lr),"seed_"+str(sd))
     else:
         save_dir = os.path.join(params["save_dir"],params["agent"]["name"],params["environment"]["map_name"],params["policy"],stoc,"buffer_"+str(params["agent"]["buffer_size"]),"lr_"+str(lr),"seed_"+str(sd))
 
@@ -144,7 +137,8 @@ def run_trial(params,device):
         policy_args = {"normalize_images":False}
     else:
         raise RuntimeError("The network strucutre is not available")
-
+
+    eval_args = {}
     if params["agent"]["name"] == "PPO":
         model = PPO(params["policy"],
                     behave_env,

@@ -183,6 +177,9 @@ def run_trial(params,device):
                     device=device)
     elif params["agent"]["name"] == "QRDQN":
         policy_args["n_quantiles"] = params["agent"]["n_quantiles"]
+        eval_args["eval_policy"] = params["agent"]["eval_policy"]
+        if params["agent"]["eval_policy"] == "Thresholded_SSD":
+            eval_args["ssd_thres"] = params["agent"]["ssd_thres"]
         model = QRDQN(params["policy"],
                       behave_env,
                       verbose=1,

@@ -198,23 +195,17 @@ def run_trial(params,device):
                       device=device)
     else:
         raise RuntimeError("The agent is not available.")
-
-    model.learn(total_timesteps=params["base"]["num_timesteps"], eval_env=evaluate_env, eval_freq=params["base"]["eval_freq"], n_eval_episodes=1, eval_log_path=save_dir)
-
-    # check number of steps since last reset of env
-    #behave_count = behave_env._get_count()
-    #evaluate_count = evaluate_env._get_count()
-    #count_file = os.path.join(save_dir,"count.txt")
-    #np.savetxt(count_file,[behave_count,evaluate_count],fmt="%d")
 
-    # save all paths in evaluation
+    model.learn(total_timesteps=params["base"]["num_timesteps"], eval_env=evaluate_env, eval_freq=params["base"]["eval_freq"], n_eval_episodes=1, eval_log_path=save_dir, **eval_args)
+
+    # save all paths in evaluations
     all_eval_paths = evaluate_env.get_all_paths()
     paths_file = os.path.join(save_dir,"eval_paths.csv")
     with open(paths_file, "w", newline="") as f:
         write = csv.writer(f)
         write.writerows(all_eval_paths)
 
-    # save all quantiles in evalution (for QR-DQN agent)
+    # save all quantiles in evaluations (for QR-DQN agent)
     if params["agent"]["name"] == "QRDQN":
         all_eval_q = evaluate_env.get_quantiles()
         np.save(os.path.join(save_dir,"eval_quantiles.npy"),all_eval_q)
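Because the QR-DQN save path now includes the evaluation policy, runs that differ only in eval_policy no longer write to the same directory. A rough sketch of the resulting layout, built the same way run_trial does but with entirely hypothetical values for the config-driven components (map, policy, the stoc flag, and the hyperparameters):

    import os

    # Hypothetical values; the real ones come from the experiment config.
    save_dir = os.path.join("experiments", "QRDQN", "Town01", "CnnPolicy",
                            "Thresholded_SSD",   # new path component added by this commit
                            "stochastic",        # assumed value of `stoc`
                            "buffer_" + str(50000),
                            "n_quantile_" + str(50),
                            "lr_" + str(1e-4),
                            "seed_" + str(0))
    print(save_dir)
    # experiments/QRDQN/Town01/CnnPolicy/Thresholded_SSD/stochastic/buffer_50000/n_quantile_50/lr_0.0001/seed_0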

thirdparty/sb3_contrib/qrdqn/qrdqn.py

Lines changed: 6 additions & 0 deletions
@@ -250,6 +250,9 @@ def learn(
         tb_log_name: str = "QRDQN",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        ##### local modification #####
+        eval_policy: str = "Greedy",
+        ssd_thres: float = 1e-03
     ) -> OffPolicyAlgorithm:
 
         return super(QRDQN, self).learn(

@@ -262,6 +265,9 @@ def learn(
             tb_log_name=tb_log_name,
             eval_log_path=eval_log_path,
             reset_num_timesteps=reset_num_timesteps,
+            ##### local modification #####
+            eval_policy=eval_policy,
+            ssd_thres=ssd_thres
         )
 
     def _excluded_save_params(self) -> List[str]:
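With the new keyword arguments threaded through QRDQN.learn, the evaluation policy can be selected directly at the learn() call; run_stable_baselines3.py builds the equivalent call from the config via **eval_args. A hedged usage sketch, assuming the patched thirdparty copies of sb3_contrib and stable_baselines3 are on the import path; my_env, my_eval_env, and the numeric settings are placeholders, not values from the repo:

    from sb3_contrib import QRDQN   # patched copy under thirdparty/ in this repo

    # `my_env` and `my_eval_env` are placeholder gym environments.
    model = QRDQN("MlpPolicy", my_env, verbose=1)
    model.learn(total_timesteps=100_000,        # assumed training budget
                eval_env=my_eval_env,
                eval_freq=10_000,               # assumed evaluation frequency
                n_eval_episodes=1,
                eval_policy="Thresholded_SSD",  # new in this commit; defaults to "Greedy"
                ssd_thres=1e-3)                 # only used by "Thresholded_SSD"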

thirdparty/stable_baselines3/common/base_class.py

Lines changed: 11 additions & 1 deletion
@@ -350,6 +350,9 @@ def _init_callback(
         eval_freq: int = 10000,
         n_eval_episodes: int = 5,
         log_path: Optional[str] = None,
+        ##### local modification #####
+        eval_policy: str = "Greedy",
+        ssd_thres: float = 1e-03
     ) -> BaseCallback:
         """
         :param callback: Callback(s) called at every step with state of the algorithm.

@@ -375,6 +378,9 @@ def _init_callback(
                 log_path=log_path,
                 eval_freq=eval_freq,
                 n_eval_episodes=n_eval_episodes,
+                ##### local modification #####
+                eval_policy=eval_policy,
+                ssd_thres=ssd_thres
             )
             callback = CallbackList([callback, eval_callback])
 
@@ -391,6 +397,9 @@ def _setup_learn(
         log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
         tb_log_name: str = "run",
+        ##### local modification #####
+        eval_policy: str = "Greedy",
+        ssd_thres: float = 1e-03
     ) -> Tuple[int, BaseCallback]:
         """
         Initialize different variables needed for training.

@@ -442,7 +451,8 @@ def _setup_learn(
         self._logger = utils.configure_logger(self.verbose, self.tensorboard_log, tb_log_name, reset_num_timesteps)
 
         # Create eval callback if needed
-        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path)
+        ##### local modification #####
+        callback = self._init_callback(callback, eval_env, eval_freq, n_eval_episodes, log_path, eval_policy, ssd_thres)
 
         return total_timesteps, callback
 
thirdparty/stable_baselines3/common/callbacks.py

Lines changed: 14 additions & 0 deletions
@@ -304,6 +304,9 @@ def __init__(
         render: bool = False,
         verbose: int = 1,
         warn: bool = True,
+        ##### local modification #####
+        eval_policy: str = "Greedy",
+        ssd_thres: float = 1e-03
     ):
         super(EvalCallback, self).__init__(callback_on_new_best, verbose=verbose)
         self.n_eval_episodes = n_eval_episodes

@@ -313,6 +316,9 @@ def __init__(
         self.deterministic = deterministic
         self.render = render
         self.warn = warn
+        ##### local modification #####
+        self.eval_policy = eval_policy
+        self.ssd_thres = ssd_thres
 
         # Convert to VecEnv for consistency
         if not isinstance(eval_env, VecEnv):

@@ -384,6 +390,9 @@ def _on_step(self) -> bool:
                 return_episode_rewards=True,
                 warn=self.warn,
                 callback=self._log_success_callback,
+                ##### local modification #####
+                eval_policy=self.eval_policy,
+                ssd_thres=self.ssd_thres
             )
 
             if self.log_path is not None:

@@ -408,6 +417,11 @@ def _on_step(self) -> bool:
             mean_reward, std_reward = np.mean(episode_rewards), np.std(episode_rewards)
             mean_ep_length, std_ep_length = np.mean(episode_lengths), np.std(episode_lengths)
             self.last_mean_reward = mean_reward
+
+            if self.eval_policy == "Thresholded_SSD":
+                print("Eval policy: Thresholded SSD, ",f"Mean threshold: {self.ssd_thres:.1f}")
+            else:
+                print("Eval policy: ",self.eval_policy)
 
             if self.verbose > 0:
                 print(f"Eval num_timesteps={self.num_timesteps}, " f"episode_reward={mean_reward:.2f} +/- {std_reward:.2f}")

thirdparty/stable_baselines3/common/evaluation.py

Lines changed: 38 additions & 3 deletions
@@ -7,6 +7,24 @@
 from stable_baselines3.common import base_class
 from stable_baselines3.common.vec_env import DummyVecEnv, VecEnv, VecMonitor, is_vecenv_wrapped
 
+##### local modification #####
+def ssd_policy(quantiles:np.ndarray, use_threshold:bool=False, mean_threshold:float=1e-03):
+    means = np.mean(quantiles,axis=0)
+    sort_idx = np.argsort(-1*means)
+    best_1 = sort_idx[0]
+    best_2 = sort_idx[1]
+    if means[best_1] - means[best_2] > mean_threshold:
+        return best_1
+    else:
+        if use_threshold:
+            signed_second_moment = -1 * np.var(quantiles,axis=0)
+        else:
+            signed_second_moment = -1 * np.mean(quantiles**2,axis=0)
+        action = best_1
+        if signed_second_moment[best_2] > signed_second_moment[best_1]:
+            action = best_2
+        return action
+
 
 def evaluate_policy(
     model: "base_class.BaseAlgorithm",

@@ -18,6 +36,9 @@ def evaluate_policy(
     reward_threshold: Optional[float] = None,
     return_episode_rewards: bool = False,
     warn: bool = True,
+    ##### local modification #####
+    eval_policy: str = "Greedy",
+    ssd_thres: float = 1e-03
 ) -> Union[Tuple[float, float], Tuple[List[float], List[int]]]:
     """
     Runs policy for ``n_eval_episodes`` episodes and returns average reward.

@@ -70,9 +91,9 @@
     )
 
     ##### local modification #####
-    # get quantiles prediction for all state action pair if the agent is QR-DQN
+    # store quantiles prediction for all state action pair if the agent is QR-DQN
     if env.save_q_vals:
-        print("predicting quantiles (QR-DQN)")
+        print("saving quantiles (QR-DQN)")
         all_quantiles = []
         for i in range(env.num_states):
             obs = env.get_obs_at_state(i)

@@ -94,7 +115,21 @@
     observations = env.reset()
     states = None
     while (episode_counts < episode_count_targets).any():
-        actions, states = model.predict(observations, state=states, deterministic=deterministic)
+        ##### local modification #####
+        if eval_policy == "Greedy":
+            actions, states = model.predict(observations, state=states, deterministic=deterministic)
+        # TODO: consider multi environments case
+        elif eval_policy == "SSD":
+            q_vals = model.predict_quantiles(observations)
+            actions = np.array([ssd_policy(q_vals.cpu().data.numpy()[0])])
+            states = None
+        elif eval_policy == "Thresholded_SSD":
+            q_vals = model.predict_quantiles(observations)
+            actions = np.array([ssd_policy(q_vals.cpu().data.numpy()[0],use_threshold=True,mean_threshold=ssd_thres)])
+            states = None
+        else:
+            raise RuntimeError("The evaluation policy is not available.")
+
         observations, rewards, dones, infos = env.step(actions)
         ##### local modification #####
         #current_rewards += rewards
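In words, ssd_policy is a tie-breaking rule over quantile predictions: it takes the two actions with the highest mean return and, when their means are within mean_threshold of each other, compares a negated second-moment statistic (the raw second moment for "SSD", the variance when use_threshold is set, as in "Thresholded_SSD") and keeps the action with the smaller dispersion. A small self-contained sketch of that behaviour on made-up numbers; the import path assumes the patched thirdparty copy of stable_baselines3 is importable, the (n_quantiles, n_actions) shape is inferred from the axis=0 reductions above, and all values are illustrative:

    import numpy as np
    # Assumes the patched thirdparty copy of stable_baselines3 is on the import path,
    # so the ssd_policy helper added above can be imported directly.
    from stable_baselines3.common.evaluation import ssd_policy

    # Toy quantile predictions for one state, shape (n_quantiles, n_actions).
    # Both actions have mean 1.0; action 0 is far more dispersed than action 1.
    quantiles = np.array([[0.0, 0.9],
                          [1.0, 1.0],
                          [2.0, 1.1]])

    print(ssd_policy(quantiles))                                          # 1: means tie, smaller second moment wins
    print(ssd_policy(quantiles, use_threshold=True, mean_threshold=0.1))  # 1: same idea, dispersion measured by variance

    # If one mean clearly exceeds the threshold gap, the highest-mean action wins outright.
    quantiles[2, 0] = 2.3   # action 0 now has mean 1.1
    print(ssd_policy(quantiles))                                          # 0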

thirdparty/stable_baselines3/common/off_policy_algorithm.py

Lines changed: 12 additions & 0 deletions
@@ -278,6 +278,9 @@ def _setup_learn(
         log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
         tb_log_name: str = "run",
+        ##### local modification #####
+        eval_policy: str = "Greedy",
+        ssd_thres: float = 1e-03
     ) -> Tuple[int, BaseCallback]:
         """
         cf `BaseAlgorithm`.

@@ -320,6 +323,9 @@ def _setup_learn(
             log_path,
             reset_num_timesteps,
             tb_log_name,
+            ##### local modification #####
+            eval_policy,
+            ssd_thres
         )
 
     def learn(

@@ -333,6 +339,9 @@ def learn(
         tb_log_name: str = "run",
         eval_log_path: Optional[str] = None,
         reset_num_timesteps: bool = True,
+        ##### local modification #####
+        eval_policy: str = "Greedy",
+        ssd_thres: float = 1e-03
     ) -> "OffPolicyAlgorithm":
 
         total_timesteps, callback = self._setup_learn(

@@ -344,6 +353,9 @@ def learn(
             eval_log_path,
             reset_num_timesteps,
             tb_log_name,
+            ##### local modification #####
+            eval_policy,
+            ssd_thres
         )
 
         callback.on_training_start(locals(), globals())
