Commit

n-step q-learning draft
opasche committed Nov 17, 2022
1 parent a6f9499 commit 82fc928
Showing 40 changed files with 109 additions and 14 deletions.
11 changes: 9 additions & 2 deletions Examples/CartPole/CartPoleFeatures/main.py
@@ -108,11 +108,18 @@ def play(agent, environement='CartPole-v1', n_episodes=5, n_timesteps=1000, plot



-agent, reward_list = train(environement='CartPole-v1', n_episodes=10000, n_timesteps=500,
-Hidden_vect=[32,32], discount_rate = 0.999, lr = 1e-3,
+agent, reward_list = train(environement='CartPole-v1', n_episodes=5000, n_timesteps=500,
+Hidden_vect=[32,24], discount_rate = 0.999, lr = 1e-3,
max_exploration_rate = 1, exploration_decay_rate = 0.001, min_exploration_rate = 0.01,#0.001,
warm_start_path=None, render_mode="rgb_array")
agent.save("model_weights/DQN_cartpole_weights_last.pt")
+replay_memory = agent.replay_memory
+
+#agent.load("model_weights/DQN_cartpole_weights_best_32_24.pt")
+
+rew_play = play(agent, environement='CartPole-v1', n_episodes=5, n_timesteps=1000, plot_rewards=False, render_mode="human")
+print(rew_play)
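
The run above passes max_exploration_rate, exploration_decay_rate, and min_exploration_rate straight through to the agent; the schedule itself lives in RLFramework/Agents.py and is not part of this diff. For orientation, a common exponential epsilon decay consistent with these parameter names (an assumption, not necessarily the repository's exact formula) is:

import numpy as np

def epsilon_schedule(episode, max_eps=1.0, min_eps=0.01, decay_rate=0.001):
    # Hypothetical schedule: decay epsilon exponentially from max_eps toward min_eps.
    # The actual DQN agent in RLFramework/Agents.py may use a different rule.
    return min_eps + (max_eps - min_eps) * np.exp(-decay_rate * episode)

# With the values used above (1 -> 0.01 at rate 0.001), epsilon is roughly 0.11
# after 2300 episodes and about 0.017 after the full 5000 episodes.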




37 binary files changed (contents not shown).
108 changes: 98 additions & 10 deletions Examples/FrozenLake/reinforcement-learning-frozen-lake.ipynb
@@ -12,8 +12,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-13T17:49:00.859832Z",
"start_time": "2022-11-13T17:48:59.118357Z"
"end_time": "2022-11-17T17:47:07.528193Z",
"start_time": "2022-11-17T17:47:06.661799Z"
}
},
"outputs": [],
@@ -32,8 +32,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-13T17:49:03.414396Z",
"start_time": "2022-11-13T17:49:00.971074Z"
"end_time": "2022-11-17T17:47:09.057967Z",
"start_time": "2022-11-17T17:47:08.038928Z"
}
},
"outputs": [],
@@ -54,8 +54,8 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-13T17:49:04.372662Z",
"start_time": "2022-11-13T17:49:04.344113Z"
"end_time": "2022-11-17T17:47:12.755822Z",
"start_time": "2022-11-17T17:47:12.727898Z"
}
},
"outputs": [],
@@ -96,8 +96,8 @@
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-13T17:49:05.590320Z",
"start_time": "2022-11-13T17:49:05.582172Z"
"end_time": "2022-11-17T17:47:14.973047Z",
"start_time": "2022-11-17T17:47:14.948087Z"
}
},
"outputs": [],
@@ -1438,6 +1438,94 @@
"rewards3, agent3, env3 = frozen_lake_play(env3, agent3, 4,\n",
" desc=desc3, map_name=map_name3, is_slippery=is_slippery3)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# n-step Q-learning"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-17T17:42:58.385626Z",
"start_time": "2022-11-17T17:42:58.372627Z"
}
},
"outputs": [],
"source": [
"def frozen_lake_train_nstep(n_episodes, n_timesteps=int(1e10), n_steps=5,\n",
" discount_rate = 0.99, lr = 0.1,\n",
" max_exploration_rate = 1, exploration_decay_rate = 0.001, min_exploration_rate = 0.001,\n",
" desc=None, map_name=\"4x4\", is_slippery=True, verbatim=0):\n",
" \n",
" #create environment frozen lake\n",
" env = gym.make(\"FrozenLake-v1\", desc=desc, map_name=map_name, is_slippery=is_slippery, render_mode=\"rgb_array_list\")\n",
" \n",
" n_states = env.observation_space.n\n",
" n_actions = env.action_space.n\n",
" \n",
" #create agent\n",
" agent = nstep_Q_agent(n_states=n_states, n_actions=n_actions, n_steps=n_steps,\n",
" discount_rate=discount_rate, lr=lr,\n",
" max_exploration_rate=max_exploration_rate, exploration_decay_rate=exploration_decay_rate,\n",
" min_exploration_rate=min_exploration_rate)\n",
" reward=None\n",
" \n",
" \n",
" reward_list = agent.train(env, n_episodes=n_episodes, max_timesteps=n_timesteps,\n",
" checkpoint_path=\"./saves/checkpoints/\", warm_start_path=None,\n",
" verbatim=verbatim, render_every=None)\n",
" \n",
" #print(np.cumsum(reward_list)/(np.arange(n_episodes) + 1))\n",
" plt.figure(figsize=(14,10))\n",
" plt.plot((np.arange(n_episodes) + 1), np.cumsum(reward_list)/(np.arange(n_episodes) + 1))\n",
" plt.show()\n",
" env.close()\n",
" return (reward_list, agent, env)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2022-11-17T17:43:02.112305Z",
"start_time": "2022-11-17T17:43:01.349109Z"
},
"scrolled": false
},
"outputs": [
{
"ename": "IndexError",
"evalue": "list index out of range",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)",
"Input \u001b[1;32mIn [6]\u001b[0m, in \u001b[0;36m<module>\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m map_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m4x4\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;66;03m#\"4x4\"\"8x8\"\u001b[39;00m\n\u001b[0;32m 3\u001b[0m is_slippery\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m----> 5\u001b[0m rewards, agent, env \u001b[38;5;241m=\u001b[39m \u001b[43mfrozen_lake_train_nstep\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m10000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1e10\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdiscount_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.99\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exploration_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexploration_decay_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.001\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mmin_exploration_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.001\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdesc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmap_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_slippery\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_slippery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbatim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 9\u001b[0m agent\u001b[38;5;241m.\u001b[39msave(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./saves/nQagent_4x4F_last.npy\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 10\u001b[0m np\u001b[38;5;241m.\u001b[39mmean(rewards)\n",
"Input \u001b[1;32mIn [4]\u001b[0m, in \u001b[0;36mfrozen_lake_train_nstep\u001b[1;34m(n_episodes, n_timesteps, n_steps, discount_rate, lr, max_exploration_rate, exploration_decay_rate, min_exploration_rate, desc, map_name, is_slippery, verbatim)\u001b[0m\n\u001b[0;32m 13\u001b[0m agent \u001b[38;5;241m=\u001b[39m nstep_Q_agent(n_states\u001b[38;5;241m=\u001b[39mn_states, n_actions\u001b[38;5;241m=\u001b[39mn_actions, n_steps\u001b[38;5;241m=\u001b[39mn_steps,\n\u001b[0;32m 14\u001b[0m discount_rate\u001b[38;5;241m=\u001b[39mdiscount_rate, lr\u001b[38;5;241m=\u001b[39mlr,\n\u001b[0;32m 15\u001b[0m max_exploration_rate\u001b[38;5;241m=\u001b[39mmax_exploration_rate, exploration_decay_rate\u001b[38;5;241m=\u001b[39mexploration_decay_rate,\n\u001b[0;32m 16\u001b[0m min_exploration_rate\u001b[38;5;241m=\u001b[39mmin_exploration_rate)\n\u001b[0;32m 17\u001b[0m reward\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m---> 20\u001b[0m reward_list \u001b[38;5;241m=\u001b[39m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_episodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_episodes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./saves/checkpoints/\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarm_start_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbatim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbatim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrender_every\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 24\u001b[0m \u001b[38;5;66;03m#print(np.cumsum(reward_list)/(np.arange(n_episodes) + 1))\u001b[39;00m\n\u001b[0;32m 25\u001b[0m plt\u001b[38;5;241m.\u001b[39mfigure(figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m14\u001b[39m,\u001b[38;5;241m10\u001b[39m))\n",
"File \u001b[1;32m~\\Desktop\\localcopy\\ReinforcementLearning\\Examples\\FrozenLake\\../..\\RLFramework\\Agents.py:139\u001b[0m, in \u001b[0;36mRLAgent.train\u001b[1;34m(self, env, n_episodes, max_timesteps, checkpoint_path, warm_start_path, verbatim, render_every)\u001b[0m\n\u001b[0;32m 137\u001b[0m new_state, reward, done, truncated, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(action)\n\u001b[0;32m 138\u001b[0m \u001b[38;5;66;03m#self.store_experience(state, action, reward, new_state, done)\u001b[39;00m\n\u001b[1;32m--> 139\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate_policy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnew_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreward\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimestep\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 140\u001b[0m state \u001b[38;5;241m=\u001b[39m new_state\n\u001b[0;32m 142\u001b[0m \u001b[38;5;66;03m# sum up the number of rewards after n episodes\u001b[39;00m\n",
"File \u001b[1;32m~\\Desktop\\localcopy\\ReinforcementLearning\\Examples\\FrozenLake\\../..\\RLFramework\\Agents.py:315\u001b[0m, in \u001b[0;36mnstep_Q_agent.update_policy\u001b[1;34m(self, old_state, new_state, reward, action, episode, done, t)\u001b[0m\n\u001b[0;32m 313\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_steps \u001b[38;5;241m-\u001b[39m t2):\n\u001b[0;32m 314\u001b[0m ki \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_steps\u001b[38;5;241m-\u001b[39mk\n\u001b[1;32m--> 315\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msmem\u001b[49m\u001b[43m[\u001b[49m\u001b[43mki\u001b[49m\u001b[43m]\u001b[49m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mamem[ki]]\u001b[38;5;241m==\u001b[39mnp\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msmem[ki], :]):\n\u001b[0;32m 316\u001b[0m G \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrmem[ki] \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdiscount_rate \u001b[38;5;241m*\u001b[39m G\u001b[38;5;241m/\u001b[39m\u001b[38;5;28mlen\u001b[39m(np\u001b[38;5;241m.\u001b[39margwhere(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msmem[ki], :]\u001b[38;5;241m==\u001b[39mnp\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msmem[ki], :]))\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,))\n\u001b[0;32m 317\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n",
"\u001b[1;31mIndexError\u001b[0m: list index out of range"
]
}
],
"source": [
"desc=None\n",
"map_name=\"4x4\"#\"4x4\"\"8x8\"\n",
"is_slippery=False\n",
"\n",
"rewards, agent, env = frozen_lake_train_nstep(10000, n_timesteps=int(1e10), n_steps=5, discount_rate = 0.99, lr = 0.1,\n",
" max_exploration_rate = 1, exploration_decay_rate = 0.001,\n",
" min_exploration_rate = 0.001,\n",
" desc=desc, map_name=map_name, is_slippery=is_slippery, verbatim=0)\n",
"agent.save(\"./saves/nQagent_4x4F_last.npy\")\n",
"np.mean(rewards)"
]
}
],
"metadata": {
@@ -1519,9 +1607,9 @@
},
"position": {
"height": "677.587px",
"left": "1798.19px",
"left": "1739.18px",
"right": "20px",
"top": "116.969px",
"top": "120.962px",
"width": "338.264px"
},
"types_to_exclude": [
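
The cells added above exercise the new nstep_Q_agent, whose update rule lives in RLFramework/Agents.py and appears here only through the traceback. As a point of reference, a minimal tabular n-step Q-learning backup (a sketch with illustrative names, not the repository's API, and ignoring the tie-splitting over greedy actions visible at Agents.py:316) buffers the last n transitions and bootstraps only once the buffer is full; the length guard also avoids the kind of out-of-range indexing raised at Agents.py:315 above.

from collections import deque
import numpy as np

def nstep_q_update(Q, buffer, new_state, done, n_steps=5, discount=0.99, lr=0.1):
    # Q: (n_states, n_actions) table; buffer: deque of (state, action, reward),
    # oldest transition first. The backup is applied to the oldest state-action pair:
    #   G = r_t + g*r_{t+1} + ... + g^(n-1)*r_{t+n-1} + g^n * max_a Q(s_{t+n}, a)
    if len(buffer) < n_steps and not done:
        return  # not enough transitions buffered yet; also avoids indexing past the end
    G = 0.0 if done else np.max(Q[new_state])  # never bootstrap from a terminal state
    for _, _, r in reversed(buffer):           # fold rewards newest -> oldest
        G = r + discount * G
    s0, a0, _ = buffer[0]
    Q[s0, a0] += lr * (G - Q[s0, a0])
    buffer.popleft()

The buffer would start empty at each episode, receive an append((state, action, reward)) after every env.step, and have this update called right after; once done is True the call can be repeated until the buffer is empty, so the final transitions also receive their (shorter) backups.
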
4 changes: 2 additions & 2 deletions RLFramework/Agents.py
@@ -136,7 +136,7 @@ def train(self, env, n_episodes=10000, max_timesteps=500,
#print(self.greedy_eps)
new_state, reward, done, truncated, info = env.step(action)
#self.store_experience(state, action, reward, new_state, done)
-self.update_policy(state, new_state, reward, action, episode, done=False, t=timestep)
+self.update_policy(state, new_state, reward, action, episode, done=done, t=timestep)
state = new_state

# sum up the number of rewards after n episodes
@@ -145,7 +145,7 @@ def train(self, env, n_episodes=10000, max_timesteps=500,
#self.update_target_network(episode)
reward_list.append(total_reward)
if ((episode+1)%100==0) and (verbatim>0):
print(f"---- Episodes {episode-98} to {episode+1} finished in {(time.time() - start_time):.2f} seconds ----")
print(f"---- Episodes {episode-98} to {episode+1} finished in {(time.time() - start_time):.2f} sec. with ave. reward: {np.mean(reward_list[(-99):]):.2f}. ----")
start_time = time.time()
if checkpoint_path:
if ((episode+1)%1000==0):
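
The first change replaces a hard-coded done=False with the flag returned by env.step. This matters because a Q-learning target should not bootstrap from a terminal state: with done always False the agent keeps adding discounted future value after, for example, falling into a FrozenLake hole. A one-line illustration of the distinction (a sketch only; the agents' update_policy methods handle this internally):

import numpy as np

def one_step_q_target(Q, reward, new_state, done, discount=0.99):
    # If done were always passed as False, the max-Q bootstrap term would be
    # added even for terminal transitions, biasing the learned values.
    return reward if done else reward + discount * np.max(Q[new_state])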
