diff --git a/Examples/CartPole/CartPoleFeatures/.pylint.d/main1.stats b/Examples/CartPole/CartPoleFeatures/.pylint.d/main1.stats new file mode 100644 index 0000000..bae71e9 Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/.pylint.d/main1.stats differ diff --git a/Examples/CartPole/CartPoleFeatures/main.py b/Examples/CartPole/CartPoleFeatures/main.py index 5c91329..3642a41 100644 --- a/Examples/CartPole/CartPoleFeatures/main.py +++ b/Examples/CartPole/CartPoleFeatures/main.py @@ -108,11 +108,18 @@ def play(agent, environement='CartPole-v1', n_episodes=5, n_timesteps=1000, plot -agent, reward_list = train(environement='CartPole-v1', n_episodes=10000, n_timesteps=500, - Hidden_vect=[32,32], discount_rate = 0.999, lr = 1e-3, +agent, reward_list = train(environement='CartPole-v1', n_episodes=5000, n_timesteps=500, + Hidden_vect=[32,24], discount_rate = 0.999, lr = 1e-3, max_exploration_rate = 1, exploration_decay_rate = 0.001, min_exploration_rate = 0.01,#0.001, warm_start_path=None, render_mode="rgb_array") agent.save("model_weights/DQN_cartpole_weights_last.pt") replay_memory = agent.replay_memory + +#agent.load("model_weights/DQN_cartpole_weights_best_32_24.pt") + rew_play = play(agent, environement='CartPole-v1', n_episodes=5, n_timesteps=1000, plot_rewards=False, render_mode="human") print(rew_play) + + + + diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_last.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_last.pt deleted file mode 100644 index ff9b93b..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_last.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_1000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_1000.pt deleted file mode 100644 index ddcadb0..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_1000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_10000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_10000.pt deleted file mode 100644 index ff9b93b..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_10000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_2000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_2000.pt deleted file mode 100644 index b15c6e6..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_2000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_3000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_3000.pt deleted file mode 100644 index 23908f2..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_3000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_4000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_4000.pt deleted file mode 100644 index 84d383c..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_4000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_5000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_5000.pt deleted file mode 100644 index 762c691..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_5000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_6000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_6000.pt deleted file mode 100644 index 9954312..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_6000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_7000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_7000.pt deleted file mode 100644 index 58ced88..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_7000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_8000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_8000.pt deleted file mode 100644 index 83b62a8..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_8000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_9000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_9000.pt deleted file mode 100644 index ea3de92..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205/DQN_cartpole_weights_traintemp_9000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_last.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_last.pt deleted file mode 100644 index dd1b303..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_last.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_1000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_1000.pt deleted file mode 100644 index 3948a63..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_1000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_2000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_2000.pt deleted file mode 100644 index 0d4f0fc..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_2000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_3000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_3000.pt deleted file mode 100644 index e6da6fa..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_3000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_4000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_4000.pt deleted file mode 100644 index 6ee844c..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_4000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_5000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_5000.pt deleted file mode 100644 index 7d3723c..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_5000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_6000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_6000.pt deleted file mode 100644 index ead16de..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_6000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_7000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_7000.pt deleted file mode 100644 index 3f1957a..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_7000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_8000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_8000.pt deleted file mode 100644 index 1404bfe..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_8000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_9000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_9000.pt deleted file mode 100644 index 1163638..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_9000.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_2000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_2000_64_32.pt similarity index 76% rename from Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_2000.pt rename to Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_2000_64_32.pt index 43000ec..66a0b46 100644 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_2000.pt and b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_2000_64_32.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_10000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_best_32_24.pt similarity index 56% rename from Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_10000.pt rename to Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_best_32_24.pt index dd1b303..33da906 100644 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/20191205_2/DQN_cartpole_weights_traintemp_10000.pt and b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_best_32_24.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_last.pt b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_last.pt index 34e41eb..515ad3f 100644 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_last.pt and b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_last.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_last0.pt b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_last0.pt deleted file mode 100644 index dd1b303..0000000 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_last0.pt and /dev/null differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights.pt b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_meh_64_32.pt similarity index 73% rename from Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights.pt rename to Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_meh_64_32.pt index 9a60a6e..8ec2206 100644 Binary files a/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights.pt and b/Examples/CartPole/CartPoleFeatures/model_weights/DQN_cartpole_weights_meh_64_32.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_1000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_1000.pt new file mode 100644 index 0000000..9b18041 Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_1000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_10000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_10000.pt new file mode 100644 index 0000000..6a00efa Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_10000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_2000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_2000.pt new file mode 100644 index 0000000..5dbbcb2 Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_2000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_3000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_3000.pt new file mode 100644 index 0000000..c25676d Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_3000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_4000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_4000.pt new file mode 100644 index 0000000..178fcaa Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_4000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_5000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_5000.pt new file mode 100644 index 0000000..b515427 Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_5000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_6000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_6000.pt new file mode 100644 index 0000000..1a29ae9 Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_6000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_7000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_7000.pt new file mode 100644 index 0000000..8b0d984 Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_7000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_8000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_8000.pt new file mode 100644 index 0000000..ab15751 Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_8000.pt differ diff --git a/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_9000.pt b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_9000.pt new file mode 100644 index 0000000..8c764cf Binary files /dev/null and b/Examples/CartPole/CartPoleFeatures/model_weights/checkpoints_last/training_checkpoint_9000.pt differ diff --git a/Examples/FrozenLake/reinforcement-learning-frozen-lake.ipynb b/Examples/FrozenLake/reinforcement-learning-frozen-lake.ipynb index 896a96d..4e1076e 100644 --- a/Examples/FrozenLake/reinforcement-learning-frozen-lake.ipynb +++ b/Examples/FrozenLake/reinforcement-learning-frozen-lake.ipynb @@ -12,8 +12,8 @@ "execution_count": 1, "metadata": { "ExecuteTime": { - "end_time": "2022-11-13T17:49:00.859832Z", - "start_time": "2022-11-13T17:48:59.118357Z" + "end_time": "2022-11-17T17:47:07.528193Z", + "start_time": "2022-11-17T17:47:06.661799Z" } }, "outputs": [], @@ -32,8 +32,8 @@ "execution_count": 2, "metadata": { "ExecuteTime": { - "end_time": "2022-11-13T17:49:03.414396Z", - "start_time": "2022-11-13T17:49:00.971074Z" + "end_time": "2022-11-17T17:47:09.057967Z", + "start_time": "2022-11-17T17:47:08.038928Z" } }, "outputs": [], @@ -54,8 +54,8 @@ "execution_count": 3, "metadata": { "ExecuteTime": { - "end_time": "2022-11-13T17:49:04.372662Z", - "start_time": "2022-11-13T17:49:04.344113Z" + "end_time": "2022-11-17T17:47:12.755822Z", + "start_time": "2022-11-17T17:47:12.727898Z" } }, "outputs": [], @@ -96,8 +96,8 @@ "execution_count": 4, "metadata": { "ExecuteTime": { - "end_time": "2022-11-13T17:49:05.590320Z", - "start_time": "2022-11-13T17:49:05.582172Z" + "end_time": "2022-11-17T17:47:14.973047Z", + "start_time": "2022-11-17T17:47:14.948087Z" } }, "outputs": [], @@ -1438,6 +1438,94 @@ "rewards3, agent3, env3 = frozen_lake_play(env3, agent3, 4,\n", " desc=desc3, map_name=map_name3, is_slippery=is_slippery3)" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# n-step Q-learning" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-17T17:42:58.385626Z", + "start_time": "2022-11-17T17:42:58.372627Z" + } + }, + "outputs": [], + "source": [ + "def frozen_lake_train_nstep(n_episodes, n_timesteps=int(1e10), n_steps=5,\n", + " discount_rate = 0.99, lr = 0.1,\n", + " max_exploration_rate = 1, exploration_decay_rate = 0.001, min_exploration_rate = 0.001,\n", + " desc=None, map_name=\"4x4\", is_slippery=True, verbatim=0):\n", + " \n", + " #create environment frozen lake\n", + " env = gym.make(\"FrozenLake-v1\", desc=desc, map_name=map_name, is_slippery=is_slippery, render_mode=\"rgb_array_list\")\n", + " \n", + " n_states = env.observation_space.n\n", + " n_actions = env.action_space.n\n", + " \n", + " #create agent\n", + " agent = nstep_Q_agent(n_states=n_states, n_actions=n_actions, n_steps=n_steps,\n", + " discount_rate=discount_rate, lr=lr,\n", + " max_exploration_rate=max_exploration_rate, exploration_decay_rate=exploration_decay_rate,\n", + " min_exploration_rate=min_exploration_rate)\n", + " reward=None\n", + " \n", + " \n", + " reward_list = agent.train(env, n_episodes=n_episodes, max_timesteps=n_timesteps,\n", + " checkpoint_path=\"./saves/checkpoints/\", warm_start_path=None,\n", + " verbatim=verbatim, render_every=None)\n", + " \n", + " #print(np.cumsum(reward_list)/(np.arange(n_episodes) + 1))\n", + " plt.figure(figsize=(14,10))\n", + " plt.plot((np.arange(n_episodes) + 1), np.cumsum(reward_list)/(np.arange(n_episodes) + 1))\n", + " plt.show()\n", + " env.close()\n", + " return (reward_list, agent, env)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "ExecuteTime": { + "end_time": "2022-11-17T17:43:02.112305Z", + "start_time": "2022-11-17T17:43:01.349109Z" + }, + "scrolled": false + }, + "outputs": [ + { + "ename": "IndexError", + "evalue": "list index out of range", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mIndexError\u001b[0m Traceback (most recent call last)", + "Input \u001b[1;32mIn [6]\u001b[0m, in \u001b[0;36m\u001b[1;34m\u001b[0m\n\u001b[0;32m 2\u001b[0m map_name\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m4x4\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;66;03m#\"4x4\"\"8x8\"\u001b[39;00m\n\u001b[0;32m 3\u001b[0m is_slippery\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m\n\u001b[1;32m----> 5\u001b[0m rewards, agent, env \u001b[38;5;241m=\u001b[39m \u001b[43mfrozen_lake_train_nstep\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m10000\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1e10\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_steps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m5\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdiscount_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.99\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mlr\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.1\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 6\u001b[0m \u001b[43m \u001b[49m\u001b[43mmax_exploration_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mexploration_decay_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.001\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 7\u001b[0m \u001b[43m \u001b[49m\u001b[43mmin_exploration_rate\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43m \u001b[49m\u001b[38;5;241;43m0.001\u001b[39;49m\u001b[43m,\u001b[49m\n\u001b[0;32m 8\u001b[0m \u001b[43m \u001b[49m\u001b[43mdesc\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdesc\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmap_name\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mmap_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mis_slippery\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mis_slippery\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mverbatim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m0\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m 9\u001b[0m agent\u001b[38;5;241m.\u001b[39msave(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m./saves/nQagent_4x4F_last.npy\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 10\u001b[0m np\u001b[38;5;241m.\u001b[39mmean(rewards)\n", + "Input \u001b[1;32mIn [4]\u001b[0m, in \u001b[0;36mfrozen_lake_train_nstep\u001b[1;34m(n_episodes, n_timesteps, n_steps, discount_rate, lr, max_exploration_rate, exploration_decay_rate, min_exploration_rate, desc, map_name, is_slippery, verbatim)\u001b[0m\n\u001b[0;32m 13\u001b[0m agent \u001b[38;5;241m=\u001b[39m nstep_Q_agent(n_states\u001b[38;5;241m=\u001b[39mn_states, n_actions\u001b[38;5;241m=\u001b[39mn_actions, n_steps\u001b[38;5;241m=\u001b[39mn_steps,\n\u001b[0;32m 14\u001b[0m discount_rate\u001b[38;5;241m=\u001b[39mdiscount_rate, lr\u001b[38;5;241m=\u001b[39mlr,\n\u001b[0;32m 15\u001b[0m max_exploration_rate\u001b[38;5;241m=\u001b[39mmax_exploration_rate, exploration_decay_rate\u001b[38;5;241m=\u001b[39mexploration_decay_rate,\n\u001b[0;32m 16\u001b[0m min_exploration_rate\u001b[38;5;241m=\u001b[39mmin_exploration_rate)\n\u001b[0;32m 17\u001b[0m reward\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[1;32m---> 20\u001b[0m reward_list \u001b[38;5;241m=\u001b[39m \u001b[43magent\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43menv\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mn_episodes\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_episodes\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mmax_timesteps\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mn_timesteps\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 21\u001b[0m \u001b[43m \u001b[49m\u001b[43mcheckpoint_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m./saves/checkpoints/\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mwarm_start_path\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m,\u001b[49m\n\u001b[0;32m 22\u001b[0m \u001b[43m \u001b[49m\u001b[43mverbatim\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mverbatim\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mrender_every\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mNone\u001b[39;49;00m\u001b[43m)\u001b[49m\n\u001b[0;32m 24\u001b[0m \u001b[38;5;66;03m#print(np.cumsum(reward_list)/(np.arange(n_episodes) + 1))\u001b[39;00m\n\u001b[0;32m 25\u001b[0m plt\u001b[38;5;241m.\u001b[39mfigure(figsize\u001b[38;5;241m=\u001b[39m(\u001b[38;5;241m14\u001b[39m,\u001b[38;5;241m10\u001b[39m))\n", + "File \u001b[1;32m~\\Desktop\\localcopy\\ReinforcementLearning\\Examples\\FrozenLake\\../..\\RLFramework\\Agents.py:139\u001b[0m, in \u001b[0;36mRLAgent.train\u001b[1;34m(self, env, n_episodes, max_timesteps, checkpoint_path, warm_start_path, verbatim, render_every)\u001b[0m\n\u001b[0;32m 137\u001b[0m new_state, reward, done, truncated, info \u001b[38;5;241m=\u001b[39m env\u001b[38;5;241m.\u001b[39mstep(action)\n\u001b[0;32m 138\u001b[0m \u001b[38;5;66;03m#self.store_experience(state, action, reward, new_state, done)\u001b[39;00m\n\u001b[1;32m--> 139\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mupdate_policy\u001b[49m\u001b[43m(\u001b[49m\u001b[43mstate\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mnew_state\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mreward\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maction\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mepisode\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdone\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdone\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mt\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtimestep\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 140\u001b[0m state \u001b[38;5;241m=\u001b[39m new_state\n\u001b[0;32m 142\u001b[0m \u001b[38;5;66;03m# sum up the number of rewards after n episodes\u001b[39;00m\n", + "File \u001b[1;32m~\\Desktop\\localcopy\\ReinforcementLearning\\Examples\\FrozenLake\\../..\\RLFramework\\Agents.py:315\u001b[0m, in \u001b[0;36mnstep_Q_agent.update_policy\u001b[1;34m(self, old_state, new_state, reward, action, episode, done, t)\u001b[0m\n\u001b[0;32m 313\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mrange\u001b[39m(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_steps \u001b[38;5;241m-\u001b[39m t2):\n\u001b[0;32m 314\u001b[0m ki \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mn_steps\u001b[38;5;241m-\u001b[39mk\n\u001b[1;32m--> 315\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43msmem\u001b[49m\u001b[43m[\u001b[49m\u001b[43mki\u001b[49m\u001b[43m]\u001b[49m, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mamem[ki]]\u001b[38;5;241m==\u001b[39mnp\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msmem[ki], :]):\n\u001b[0;32m 316\u001b[0m G \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mrmem[ki] \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdiscount_rate \u001b[38;5;241m*\u001b[39m G\u001b[38;5;241m/\u001b[39m\u001b[38;5;28mlen\u001b[39m(np\u001b[38;5;241m.\u001b[39margwhere(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msmem[ki], :]\u001b[38;5;241m==\u001b[39mnp\u001b[38;5;241m.\u001b[39mmax(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mQ[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msmem[ki], :]))\u001b[38;5;241m.\u001b[39mreshape(\u001b[38;5;241m-\u001b[39m\u001b[38;5;241m1\u001b[39m,))\n\u001b[0;32m 317\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "\u001b[1;31mIndexError\u001b[0m: list index out of range" + ] + } + ], + "source": [ + "desc=None\n", + "map_name=\"4x4\"#\"4x4\"\"8x8\"\n", + "is_slippery=False\n", + "\n", + "rewards, agent, env = frozen_lake_train_nstep(10000, n_timesteps=int(1e10), n_steps=5, discount_rate = 0.99, lr = 0.1,\n", + " max_exploration_rate = 1, exploration_decay_rate = 0.001,\n", + " min_exploration_rate = 0.001,\n", + " desc=desc, map_name=map_name, is_slippery=is_slippery, verbatim=0)\n", + "agent.save(\"./saves/nQagent_4x4F_last.npy\")\n", + "np.mean(rewards)" + ] } ], "metadata": { @@ -1519,9 +1607,9 @@ }, "position": { "height": "677.587px", - "left": "1798.19px", + "left": "1739.18px", "right": "20px", - "top": "116.969px", + "top": "120.962px", "width": "338.264px" }, "types_to_exclude": [ diff --git a/RLFramework/Agents.py b/RLFramework/Agents.py index 898a74a..17f78ec 100644 --- a/RLFramework/Agents.py +++ b/RLFramework/Agents.py @@ -136,7 +136,7 @@ def train(self, env, n_episodes=10000, max_timesteps=500, #print(self.greedy_eps) new_state, reward, done, truncated, info = env.step(action) #self.store_experience(state, action, reward, new_state, done) - self.update_policy(state, new_state, reward, action, episode, done=False, t=timestep) + self.update_policy(state, new_state, reward, action, episode, done=done, t=timestep) state = new_state # sum up the number of rewards after n episodes @@ -145,7 +145,7 @@ def train(self, env, n_episodes=10000, max_timesteps=500, #self.update_target_network(episode) reward_list.append(total_reward) if ((episode+1)%100==0) and (verbatim>0): - print(f"---- Episodes {episode-98} to {episode+1} finished in {(time.time() - start_time):.2f} seconds ----") + print(f"---- Episodes {episode-98} to {episode+1} finished in {(time.time() - start_time):.2f} sec. with ave. reward: {np.mean(reward_list[(-99):]):.2f}. ----") start_time = time.time() if checkpoint_path: if ((episode+1)%1000==0):