
Commit f0fd5ed

Modified Reward Function, added local run functionality to TrainDispatcher.py, added functionality to save best model during training
1 parent 4b4fdcf · commit f0fd5ed

12 files changed: +66 −52 lines changed

SBAgent/EvaluateExperiment.py

+5 −1
@@ -5,6 +5,7 @@
 import argparse
 import json
 import numpy as np
+import random
 from envs.utils.EnvBuilder import EnvBuilder
 from stable_baselines3 import PPO
 from tqdm import tqdm
@@ -18,6 +19,9 @@
 args = parser.parse_args()


+np.random.seed(42)
+random.seed(42)
+
 with open(args.experimentConfigFile, 'r') as f:
     experimentConfig = json.load(f)

@@ -28,7 +32,7 @@
 print(f"Running Evaluation on {experimentName}")

 env = EnvBuilder.buildEnvFromConfig(os.path.join('..', 'configs', configFileName), gui=args.gui)
-agent = PPO.load(os.path.join('models', modelName))
+agent = PPO.load(os.path.join('models', modelName, 'best_model'))

 totalTrials = args.trials
 successfulTrials = 0
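Note: the evaluation now loads the best checkpoint written by the EvalCallback introduced in SBAgent/TrainModel.py below. A minimal sketch of how the paths line up; the model name here is a hypothetical example:

import os
from stable_baselines3 import PPO

modelName = "example_model"  # hypothetical; the real script takes this from the experiment config
# EvalCallback writes the best checkpoint as <best_model_save_path>/best_model.zip,
# and PPO.load appends the .zip extension automatically.
agent = PPO.load(os.path.join('models', modelName, 'best_model'))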

SBAgent/EvaluateModel.py

+18 −21
@@ -7,22 +7,23 @@
 from envs.utils.EnvBuilder import EnvBuilder
 from stable_baselines3 import PPO
 from tqdm import tqdm
+from tabulate import tabulate

 parser = argparse.ArgumentParser()
 parser.add_argument("configFileName", help="Name of the environment config file.", type=str)
-parser.add_argument("inputModelName", help="(base|finetuned + )Name of the model to load.", type=str)
-parser.add_argument("--trials", default=100, help="Number of episodes to evaluate for.", type=int)
-parser.add_argument("--gui", action=argparse.BooleanOptionalAction, help="Whether or not to show GUI")
-
+parser.add_argument("inputModelPath", help="Path to model to evaluate", type=str)
+parser.add_argument("-t", "--trials", default=100, help="Number of episodes to evaluate for.", type=int)
+parser.add_argument('--gui', action='store_true', help='Enable GUI')
+parser.add_argument('--no-gui', action='store_false', dest='gui', help='Disable GUI')
 args = parser.parse_args()

 configFileName = args.configFileName
-modelName = args.outputModelName
+modelName = args.inputModelPath

 env = EnvBuilder.buildEnvFromConfig(os.path.join('..', 'configs', configFileName), gui=args.gui)
-agent = PPO.load(os.path.join('models', modelName))
+agent = PPO.load(modelName)

-totalTrials = args.trails
+totalTrials = args.trials
 successfulTrials = 0
 rewards = []
 durations = []
@@ -42,7 +43,7 @@

     if info['success']:
         successfulTrials += 1
-    if info['reason'] == "collision":
+    elif info['reason'] == "collision":
         nCollisions +=1
     else:
         incompleteDistances.append(np.linalg.norm(obs[:(obs.shape[0]//2)]))
@@ -52,17 +53,13 @@

 env.close()

+evaluationResults = {
+    'Success Rate': f"{successfulTrials/totalTrials * 100:.2f}%",
+    'Collision Rate': f"{nCollisions/totalTrials * 100:.2f}%",
+    'Mean Incompletion Distance': f"{sum(incompleteDistances)/len(incompleteDistances):.2f}m" if len(incompleteDistances) > 0 else "N/A",
+    'Mean Reward': f"{sum(rewards)/len(rewards):.2f}",
+    'Mean Episode Length': f"{sum(durations)/len(durations)}"
+}

-
-
-print(f"---------------------------------------------------------")
-print(f"EVALUATION STATISTICS")
-print()
-print(f"Success Rate: {successfulTrials/totalTrials * 100:.2f}%")
-print(f"Mean Reward: {sum(rewards)/len(rewards):.2f}")
-print(f"Minimum Reward: {min(rewards):.2f}")
-print(f"Maximium Reward: {max(rewards):.2f}")
-print(f"Mean Episode Duration: {sum(durations)/len(durations):.2f} steps")
-print(f"Shortest Episode: {min(durations)} steps")
-print(f"Longest Episode: {max(durations)} steps")
-print(f"---------------------------------------------------------")
+evaluationTable = [[k, v] for k,v in evaluationResults.items()]
+print(tabulate(evaluationTable, headers=["Metric", "Value"], tablefmt='github'))
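For illustration, the tabulate call renders the results dictionary as a GitHub-style table. A self-contained sketch with made-up numbers:

from tabulate import tabulate

# Hypothetical values, for illustration only
evaluationResults = {'Success Rate': '87.00%', 'Collision Rate': '9.00%', 'Mean Reward': '412.53'}
evaluationTable = [[k, v] for k, v in evaluationResults.items()]
print(tabulate(evaluationTable, headers=["Metric", "Value"], tablefmt='github'))
# Prints something like:
# | Metric         | Value   |
# |----------------|---------|
# | Success Rate   | 87.00%  |
# | Collision Rate | 9.00%   |
# | Mean Reward    | 412.53  |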

SBAgent/TrainModel.py

+8 −9
@@ -4,7 +4,7 @@
 import os
 import argparse
 from stable_baselines3 import PPO
-from stable_baselines3.common.callbacks import CheckpointCallback
+from stable_baselines3.common.callbacks import EvalCallback
 from envs.utils.EnvBuilder import EnvBuilder


@@ -22,12 +22,11 @@

 env = EnvBuilder.buildEnvFromConfig(os.path.join('..', 'configs', configFileName), gui=False)

-checkpoint_callback = CheckpointCallback(
-    save_freq=1000000,
-    save_path=os.path.join("checkpoints", modelName),
-    name_prefix=f"chkpt",
-)
+eval_callback = EvalCallback(env, best_model_save_path=os.path.join('models', modelName),
+                             log_path=os.path.join('sbEvalLogs', modelName),
+                             eval_freq=100_000, deterministic=True, render=False)
+
+agent = PPO("MlpPolicy", env, verbose=1, tensorboard_log=os.path.join('logs', modelName))
+agent.learn(n_steps, callback=eval_callback, tb_log_name="train_logs")
+

-agent = PPO('MlpPolicy', env, verbose=1, tensorboard_log=os.path.join('logs', modelName))
-agent.learn(n_steps, callback=checkpoint_callback, tb_log_name="train_logs")
-agent.save(os.path.join('models', modelName))
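One caveat: Stable-Baselines3 generally recommends passing EvalCallback a dedicated evaluation environment rather than reusing the training env, since running evaluation episodes on the training env resets it mid-rollout. A minimal sketch of that variant, assuming EnvBuilder can build a second instance from the same config:

from stable_baselines3.common.callbacks import EvalCallback

# Hypothetical variant: build a separate env purely for periodic evaluation
eval_env = EnvBuilder.buildEnvFromConfig(os.path.join('..', 'configs', configFileName), gui=False)
eval_callback = EvalCallback(eval_env,
                             best_model_save_path=os.path.join('models', modelName),
                             log_path=os.path.join('sbEvalLogs', modelName),
                             eval_freq=100_000,  # evaluate every 100k environment steps
                             deterministic=True, render=False)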
File renamed without changes.

TrainDispatcher.py

+30 −17
@@ -8,29 +8,42 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("experimentConfigFile", help="Experiment Config File Path")
 parser.add_argument("-s", "--steps", default=2_000_000, help="Number of timesteps to train for", type=int)
+parser.add_argument('--local', action='store_true', help='Run on Local Machine')
 args = parser.parse_args()

-with open('trainScriptTemplate.sh', 'r') as f:
-    script = ''.join(f.readlines())

-with open(args.experimentConfigFile, 'r') as f:
-    experimentConfig = json.load(f)
+if args.local:
+    with open(args.experimentConfigFile, 'r') as f:
+        experimentConfig = json.load(f)

-experimentName = experimentConfig["name"]
-envConfig = experimentConfig["trainParameters"]["config"]
-modelName = experimentConfig["trainParameters"]["outputModelName"]
+    experimentName = experimentConfig["name"]
+    envConfig = experimentConfig["trainParameters"]["config"]
+    modelName = experimentConfig["trainParameters"]["outputModelName"]

-script = script.replace("{outputFile}", f"jobOutputs/{experimentName}_train_output.txt")
-script = script.replace("{jobName}", f"{experimentName}_train")
-script = script.replace("{configFile}", envConfig)
-script = script.replace("{outputModelName}", modelName)
-script = script.replace("{steps}", str(args.steps))
+    os.chdir('SBAgent')
+    os.system(f"python TrainModel.py {envConfig} {modelName} --steps {args.steps}")
+else:
+    with open('trainScriptTemplate.sh', 'r') as f:
+        script = ''.join(f.readlines())

-tmp = tempfile.NamedTemporaryFile()
+    with open(args.experimentConfigFile, 'r') as f:
+        experimentConfig = json.load(f)

-with open(tmp.name, 'w') as f:
-    f.write(script)
+    experimentName = experimentConfig["name"]
+    envConfig = experimentConfig["trainParameters"]["config"]
+    modelName = experimentConfig["trainParameters"]["outputModelName"]

-print(f"Dispatching Train Job for {experimentName}")
+    script = script.replace("{outputFile}", f"jobOutputs/{experimentName}_train_output.txt")
+    script = script.replace("{jobName}", f"{experimentName}_train")
+    script = script.replace("{configFile}", envConfig)
+    script = script.replace("{outputModelName}", modelName)
+    script = script.replace("{steps}", str(args.steps))

-os.system(f"sbatch {tmp.name}")
+    tmp = tempfile.NamedTemporaryFile()
+
+    with open(tmp.name, 'w') as f:
+        f.write(script)
+
+    print(f"Dispatching Train Job for {experimentName}")
+
+    os.system(f"sbatch {tmp.name}")

envs/ObstacleAviary.py

+5 −4
@@ -13,7 +13,7 @@ class ObstacleAviary(BaseSingleAgentAviary):

     CLOSE_TO_FINISH_REWARD = 5
     SUCCESS_REWARD = 1000
-    COLLISION_PENALTY = -100
+    COLLISION_PENALTY = -1000

     SUCCESS_EPSILON = 0.1

@@ -226,8 +226,8 @@ def _computeReward(self):
         if np.linalg.norm(self.targetPos - pos) < ObstacleAviary.SUCCESS_EPSILON:
             return ObstacleAviary.SUCCESS_REWARD

-        if np.linalg.norm(self.targetPos - pos) < ObstacleAviary.MINOR_SAFETY_BOUND_RADIUS:
-            return ObstacleAviary.CLOSE_TO_FINISH_REWARD
+        # if np.linalg.norm(self.targetPos - pos) < ObstacleAviary.MINOR_SAFETY_BOUND_RADIUS:
+        #     return ObstacleAviary.CLOSE_TO_FINISH_REWARD

         offsetToClosestObstacle = self._computeOffsetToClosestObstacle()

@@ -239,7 +239,8 @@ def _computeReward(self):
         majorBoundBreach = distToClosestObstacle < ObstacleAviary.MAJOR_SAFETY_BOUND_RADIUS
         minorBoundBreach = distToClosestObstacle < ObstacleAviary.MINOR_SAFETY_BOUND_RADIUS

-        return 0.5*np.linalg.norm(pos - self.initPos) - 2*np.linalg.norm(self.targetPos - pos) - 10*majorBoundBreach - 2*minorBoundBreach
+        # return 0.5*np.linalg.norm(pos - self.initPos) - 2*np.linalg.norm(self.targetPos - pos) - 10*majorBoundBreach - 2*minorBoundBreach
+        return -2*np.linalg.norm(self.targetPos - pos) - 1*majorBoundBreach - 0.1*minorBoundBreach


     def _computeOffsetToClosestObstacle(self):
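In short, the per-step reward is now a pure distance-to-target shaping term plus small penalties while inside the safety bounds; the terminal SUCCESS_REWARD and the raised COLLISION_PENALTY are presumably applied elsewhere on episode termination. A standalone sketch of the new per-step value, with hypothetical bound radii since the real constants live elsewhere in ObstacleAviary:

import numpy as np

MAJOR_SAFETY_BOUND_RADIUS = 0.2  # hypothetical value, for illustration only
MINOR_SAFETY_BOUND_RADIUS = 0.4  # hypothetical value, for illustration only

def step_reward(pos, target_pos, dist_to_closest_obstacle):
    # Dense shaping: pull toward the target, lightly penalise proximity to obstacles
    major_breach = dist_to_closest_obstacle < MAJOR_SAFETY_BOUND_RADIUS
    minor_breach = dist_to_closest_obstacle < MINOR_SAFETY_BOUND_RADIUS
    return -2 * np.linalg.norm(target_pos - pos) - 1 * major_breach - 0.1 * minor_breach

# Example: 1 m from the target, clear of both bounds -> reward of -2.0
print(step_reward(np.array([0.0, 0.0, 1.0]), np.array([1.0, 0.0, 1.0]), 0.5))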
Binary files changed (size deltas of −53 bytes and +11 bytes among them); contents not shown.
