Skip to content

Commit

Permalink
Last fine tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-lund committed Sep 19, 2021
1 parent ebc06c8 commit c243a51
Showing 1 changed file with 11 additions and 20 deletions.
31 changes: 11 additions & 20 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,7 @@
from cli import run_cli_cmnds


def main():
model_name = "blitz5k"
model = neural_net.load(f"../models/{model_name}.pt")

def test():
# env = discrete_env_with_nn(rf.right, model)
# from gym_cartpole_swingup.envs import cartpole_swingup
# env = cartpole_swingup.CartPoleSwingUpV1()
Expand All @@ -29,33 +26,27 @@ def main():
input("continue?")
histories = agents.run(ppo, env, 10)

with open("history_ppo.p", "wb") as f:
dill.dump(histories, f)

evaluation.plot_angles(histories[0], model_name)
evaluation.plot_angles(histories[0], "no model")


def test2():
def main():
model_name = "blitz5k"
model = neural_net.load(f"../models/{model_name}.pt")

env = neural_net.USUCEnvWithNN.create(model, rf.best, "../discrete-usuc-dataset")
env.reset(1)
history = utils.random_actions(env)
evaluation.plot_angles(history, model_name)
evaluation.plot_reward_angle(history)

ppo = agents.create("ppo", env)
agents.train(ppo, total_timesteps=10000)
agents.save(ppo, "../agents/ppo")

def analysis():
# load history for analysis
with open("history_ppo.p", "rb") as f:
history = dill.load(f)
input("continue?")
histories = agents.run(ppo, env, 10)

evaluation.plot_reward_angle(history)
evaluation.plot_angles(histories[0], model_name)
evaluation.plot_reward_angle(histories[0])


if __name__ == "__main__":
test2()
# test()
# main()
# analysis()
run_cli_cmnds()

0 comments on commit c243a51

Please sign in to comment.