From faf9ffb59cd1f1126cc80d99115e50bf8ba2a30b Mon Sep 17 00:00:00 2001 From: ReykCS Date: Sun, 29 Jan 2023 18:43:40 +0100 Subject: [PATCH] added more params for webapp | removed accidental merge conflict relicts --- arena_bringup/launch/start_arena.launch | 12 ++++++++- arena_bringup/launch/start_training.launch | 10 +++++-- training/scripts/train_agent.py | 31 +++------------------- 3 files changed, 22 insertions(+), 31 deletions(-) diff --git a/arena_bringup/launch/start_arena.launch b/arena_bringup/launch/start_arena.launch index 94b5aec2..29f04204 100755 --- a/arena_bringup/launch/start_arena.launch +++ b/arena_bringup/launch/start_arena.launch @@ -7,11 +7,21 @@ + + + + + + + - + + + + diff --git a/arena_bringup/launch/start_training.launch b/arena_bringup/launch/start_training.launch index 12ec4ec2..e8152c06 100644 --- a/arena_bringup/launch/start_training.launch +++ b/arena_bringup/launch/start_training.launch @@ -3,12 +3,18 @@ + + + + + + + - + - diff --git a/training/scripts/train_agent.py b/training/scripts/train_agent.py index 5036bb35..aa6aefbf 100644 --- a/training/scripts/train_agent.py +++ b/training/scripts/train_agent.py @@ -1,22 +1,8 @@ #!/usr/bin/env python -<<<<<<< HEAD import sys, rospy, time -======= -from typing import Type, Union - -import os, sys, rospy, time from std_msgs.msg import Empty -from stable_baselines3 import PPO -from stable_baselines3.common.vec_env import SubprocVecEnv, DummyVecEnv -from stable_baselines3.common.callbacks import ( - EvalCallback, - StopTrainingOnRewardThreshold, -) -from stable_baselines3.common.policies import ActorCriticPolicy, BasePolicy ->>>>>>> dev - from rosnav.model.agent_factory import AgentFactory from tools.argsparser import parse_training_args from tools.general import * @@ -61,6 +47,8 @@ def main(): eval_cb = init_callbacks(config, train_env, eval_env, PATHS) model = get_ppo_instance(config, train_env, PATHS, AgentFactory) + rospy.on_shutdown(model.env.close()) + # start training start = time.time() try: @@ -72,22 +60,10 @@ def main(): except KeyboardInterrupt: print("KeyboardInterrupt..") -<<<<<<< HEAD - rospy.on_shutdown(model.env.close()) - print(f"Time passed: {time.time()-start}s. \n Training script will be terminated..") -======= - model.learn( - total_timesteps=n_timesteps, - callback=eval_cb, - reset_num_timesteps=True, - ) - # update the timesteps the model has trained in total - # update_total_timesteps_json(n_timesteps, PATHS) + print(f"Time passed: {time.time()-start}s. \n Training script will be terminated..") model.env.close() - print(f"Time passed: {time.time()-start}s") - print("Training script will be terminated") publisher = rospy.Publisher("training_finished", Empty, queue_size=10) @@ -95,7 +71,6 @@ def main(): publisher.publish(Empty()) rospy.sleep(0.1) ->>>>>>> dev sys.exit()