Skip to content

Commit f429a9c

Browse files
committed
updated versions in rosinstall | added seconds to model name | removed wandb because its missing in stable baselines branch
1 parent 75e70f9 commit f429a9c

File tree

7 files changed

+31
-14
lines changed

7 files changed

+31
-14
lines changed

.rosinstall

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
- git:
3939
local-name: ../utils/arena-utils
4040
uri: https://github.com/Arena-Rosnav/arena-utils.git
41-
version: v2.1.0
41+
version: v2.2.0
4242

4343
- git:
4444
local-name: ../utils/task-generator
@@ -55,7 +55,7 @@
5555
- git:
5656
local-name: ../planners/rosnav
5757
uri: https://github.com/Arena-Rosnav/rosnav.git
58-
version: v1.1.1
58+
version: v1.1.2
5959

6060
- git:
6161
local-name: ../planners/arena-ros

arena_bringup/launch/start_arena.launch

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,19 @@
99
<arg name="task_id" default="" />
1010
<arg name="app_token" default="" />
1111
<arg name="app_token_key" default="" />
12-
<arg name="task_finished_url" default="" />
12+
<arg name="base_url" default="" />
13+
<arg name="task_finished_endpoint" default="" />
14+
<arg name="new_best_model_endpoint" default="" />
1315

16+
<param name="is_webapp_docker" value="$(arg is_webapp_docker)" />
1417
<param name="task_id" value="$(arg task_id)" />
1518
<param name="app_token" value="$(arg app_token)" />
1619
<param name="app_token_key" value="$(arg app_token_key)" />
17-
<param name="task_finished_url" value="$(arg task_finished_url)" />
20+
<param name="base_url" value="$(arg base_url)" />
21+
<param name="task_finished_endpoint" value="$(arg task_finished_endpoint)" />
22+
<param name="new_best_model_endpoint" value="$(arg new_best_model_endpoint)" />
1823

19-
<node name="task_progress_publisher" type="task_progress_publisher.py" pkg="task_progress_publisher" if="$(eval arg('is_webapp_docker') == true)" />
24+
<node name="task_progress_publisher" type="task_progress_publisher.py" pkg="arena-utils" if="$(eval arg('is_webapp_docker') == true)" />
2025
<!-- -->
2126

2227
<arg name="desired_resets" default="2" />

arena_bringup/launch/start_training.launch

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,17 @@
55
<arg name="task_id" default="" />
66
<arg name="app_token" default="" />
77
<arg name="app_token_key" default="" />
8-
<arg name="task_finished_url" default="" />
8+
<arg name="base_url" default="" />
9+
<arg name="task_finished_endpoint" default="" />
10+
<arg name="new_best_model_endpoint" default="" />
911

12+
<param name="is_webapp_docker" value="$(arg is_webapp_docker)" />
1013
<param name="task_id" value="$(arg task_id)" />
1114
<param name="app_token" value="$(arg app_token)" />
1215
<param name="app_token_key" value="$(arg app_token_key)" />
13-
<param name="task_finished_url" value="$(arg task_finished_url)" />
16+
<param name="base_url" value="$(arg base_url)" />
17+
<param name="task_finished_endpoint" value="$(arg task_finished_endpoint)" />
18+
<param name="new_best_model_endpoint" value="$(arg new_best_model_endpoint)" />
1419

1520
<node name="task_progress_publisher" type="task_progress_publisher.py" pkg="arena-utils" if="$(eval arg('is_webapp_docker') == true)" />
1621
<!-- -->

training/configs/training_config.yaml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ no_gpu: false
1010
### Training Monitoring
1111
monitoring:
1212
# weights and biases logging
13-
use_wandb: true
13+
use_wandb: false
1414
# save evaluation stats during training in log file
1515
eval_log: false
1616

@@ -71,4 +71,3 @@ rl_agent:
7171
m_batch_size: 20
7272
n_epochs: 3
7373
clip_range: 0.22
74-

training/scripts/train_agent.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99
from tools.model_utils import init_callbacks, get_ppo_instance
1010
from tools.env_utils import init_envs
1111

12+
def on_shutdown(model):
13+
model.env.close()
14+
sys.exit()
15+
1216

1317
def main():
1418
args, _ = parse_training_args()
19+
1520
config = load_config(args.config)
1621

1722
populate_ros_configs(config)
@@ -47,7 +52,11 @@ def main():
4752
eval_cb = init_callbacks(config, train_env, eval_env, PATHS)
4853
model = get_ppo_instance(config, train_env, PATHS, AgentFactory)
4954

50-
rospy.on_shutdown(model.env.close())
55+
rospy.on_shutdown(lambda: on_shutdown(model))
56+
57+
## Save model once
58+
if not config["debug_mode"]:
59+
model.save(os.path.join(PATHS["model"], "best_model"))
5160

5261
# start training
5362
start = time.time()

training/tools/general.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def generate_agent_name(config: dict) -> str:
220220
:param config (dict): Dict containing the program arguments
221221
"""
222222
if config["rl_agent"]["resume"] is None:
223-
START_TIME = dt.now().strftime("%Y_%m_%d__%H_%M")
223+
START_TIME = dt.now().strftime("%Y_%m_%d__%H_%M_%S")
224224
robot_model = rospy.get_param("robot_model")
225225
architecture_name, encoder_name = config["rl_agent"][
226226
"architecture_name"

training/tools/model_utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23
from typing import Union, Type
34

45
import wandb
@@ -144,9 +145,7 @@ def instantiate_new_model(
144145
"n_epochs": ppo_config["n_epochs"],
145146
"clip_range": ppo_config["clip_range"],
146147
"tensorboard_log": PATHS["tb"],
147-
"use_wandb": False
148-
if config["debug_mode"]
149-
else config["monitoring"]["use_wandb"],
148+
# "use_wandb": False if config["debug_mode"] else config["monitoring"]["use_wandb"],
150149
"verbose": 1,
151150
}
152151

0 commit comments

Comments
 (0)