Skip to content

Commit

Permalink
updated versions in rosinstall | added seconds to model name | remove…
Browse files Browse the repository at this point in the history
…d wandb because it's missing in the stable baselines branch
  • Loading branch information
ReykCS committed Jan 30, 2023
1 parent 75e70f9 commit f429a9c
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 14 deletions.
4 changes: 2 additions & 2 deletions .rosinstall
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
- git:
local-name: ../utils/arena-utils
uri: https://github.com/Arena-Rosnav/arena-utils.git
version: v2.1.0
version: v2.2.0

- git:
local-name: ../utils/task-generator
Expand All @@ -55,7 +55,7 @@
- git:
local-name: ../planners/rosnav
uri: https://github.com/Arena-Rosnav/rosnav.git
version: v1.1.1
version: v1.1.2

- git:
local-name: ../planners/arena-ros
Expand Down
11 changes: 8 additions & 3 deletions arena_bringup/launch/start_arena.launch
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,19 @@
<arg name="task_id" default="" />
<arg name="app_token" default="" />
<arg name="app_token_key" default="" />
<arg name="task_finished_url" default="" />
<arg name="base_url" default="" />
<arg name="task_finished_endpoint" default="" />
<arg name="new_best_model_endpoint" default="" />

<param name="is_webapp_docker" value="$(arg is_webapp_docker)" />
<param name="task_id" value="$(arg task_id)" />
<param name="app_token" value="$(arg app_token)" />
<param name="app_token_key" value="$(arg app_token_key)" />
<param name="task_finished_url" value="$(arg task_finished_url)" />
<param name="base_url" value="$(arg base_url)" />
<param name="task_finished_endpoint" value="$(arg task_finished_endpoint)" />
<param name="new_best_model_endpoint" value="$(arg new_best_model_endpoint)" />

<node name="task_progress_publisher" type="task_progress_publisher.py" pkg="task_progress_publisher" if="$(eval arg('is_webapp_docker') == true)" />
<node name="task_progress_publisher" type="task_progress_publisher.py" pkg="arena-utils" if="$(eval arg('is_webapp_docker') == true)" />
<!-- -->

<arg name="desired_resets" default="2" />
Expand Down
9 changes: 7 additions & 2 deletions arena_bringup/launch/start_training.launch
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
<arg name="task_id" default="" />
<arg name="app_token" default="" />
<arg name="app_token_key" default="" />
<arg name="task_finished_url" default="" />
<arg name="base_url" default="" />
<arg name="task_finished_endpoint" default="" />
<arg name="new_best_model_endpoint" default="" />

<param name="is_webapp_docker" value="$(arg is_webapp_docker)" />
<param name="task_id" value="$(arg task_id)" />
<param name="app_token" value="$(arg app_token)" />
<param name="app_token_key" value="$(arg app_token_key)" />
<param name="task_finished_url" value="$(arg task_finished_url)" />
<param name="base_url" value="$(arg base_url)" />
<param name="task_finished_endpoint" value="$(arg task_finished_endpoint)" />
<param name="new_best_model_endpoint" value="$(arg new_best_model_endpoint)" />

<node name="task_progress_publisher" type="task_progress_publisher.py" pkg="arena-utils" if="$(eval arg('is_webapp_docker') == true)" />
<!-- -->
Expand Down
3 changes: 1 addition & 2 deletions training/configs/training_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ no_gpu: false
### Training Monitoring
monitoring:
# weights and biases logging
use_wandb: true
use_wandb: false
# save evaluation stats during training in log file
eval_log: false

Expand Down Expand Up @@ -71,4 +71,3 @@ rl_agent:
m_batch_size: 20
n_epochs: 3
clip_range: 0.22

11 changes: 10 additions & 1 deletion training/scripts/train_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
from tools.model_utils import init_callbacks, get_ppo_instance
from tools.env_utils import init_envs

def on_shutdown(model):
    """Close the model's environment and terminate the process.

    Intended to be used as a ROS shutdown hook; ``main`` registers it with
    ``rospy.on_shutdown(lambda: on_shutdown(model))``.

    :param model: object exposing an ``env`` attribute that has a
        ``close()`` method.
    :raises SystemExit: always, via ``sys.exit()``, once the environment
        has been closed.
    """
    # Close the environment first; sys.exit() never returns.
    model.env.close()
    sys.exit()


def main():
args, _ = parse_training_args()

config = load_config(args.config)

populate_ros_configs(config)
Expand Down Expand Up @@ -47,7 +52,11 @@ def main():
eval_cb = init_callbacks(config, train_env, eval_env, PATHS)
model = get_ppo_instance(config, train_env, PATHS, AgentFactory)

rospy.on_shutdown(model.env.close())
rospy.on_shutdown(lambda: on_shutdown(model))

## Save model once
if not config["debug_mode"]:
model.save(os.path.join(PATHS["model"], "best_model"))

# start training
start = time.time()
Expand Down
2 changes: 1 addition & 1 deletion training/tools/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def generate_agent_name(config: dict) -> str:
:param config (dict): Dict containing the program arguments
"""
if config["rl_agent"]["resume"] is None:
START_TIME = dt.now().strftime("%Y_%m_%d__%H_%M")
START_TIME = dt.now().strftime("%Y_%m_%d__%H_%M_%S")
robot_model = rospy.get_param("robot_model")
architecture_name, encoder_name = config["rl_agent"][
"architecture_name"
Expand Down
5 changes: 2 additions & 3 deletions training/tools/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys
from typing import Union, Type

import wandb
Expand Down Expand Up @@ -144,9 +145,7 @@ def instantiate_new_model(
"n_epochs": ppo_config["n_epochs"],
"clip_range": ppo_config["clip_range"],
"tensorboard_log": PATHS["tb"],
"use_wandb": False
if config["debug_mode"]
else config["monitoring"]["use_wandb"],
# "use_wandb": False if config["debug_mode"] else config["monitoring"]["use_wandb"],
"verbose": 1,
}

Expand Down

0 comments on commit f429a9c

Please sign in to comment.