Skip to content

Commit

Permalink
Rename utilities to help.
Browse files Browse the repository at this point in the history
  • Loading branch information
hallvardnmbu committed Mar 12, 2024
1 parent af5b2c4 commit 08302f4
Show file tree
Hide file tree
Showing 13 changed files with 55 additions and 132 deletions.
84 changes: 22 additions & 62 deletions breakout/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,30 @@
{
"cell_type": "markdown",
"source": [
"# Visualise the pre-trained agent in action\n",
"\n",
"Modify the path to the weights and run the notebook."
"# Example of visualising the agent's training history and performance"
],
"metadata": {
"collapsed": false
},
"id": "b3d8465ecb86eca7"
},
{
"cell_type": "markdown",
"source": [
"- `MODEL`: Path to the pre-trained model\n",
"- `METRICS`: Path to the training history, or None"
],
"metadata": {
"collapsed": false
},
"id": "fff872a8189754af"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"WEIGHTS = './_output/weights-15000.pth'\n",
"METRICS = './_output/metrics.csv'"
"MODEL = './results/model.pth'\n",
"METRICS = None"
],
"metadata": {
"collapsed": false
Expand All @@ -34,11 +43,9 @@
"import gymnasium as gym\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from DQN import VisionDeepQ\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import graph # noqa\n",
"from utilities.visualisation.movie import movie # noqa"
"from help.visualisation.plot import graph # noqa\n",
"from help.visualisation.movie import movie # noqa"
],
"metadata": {
"collapsed": false
Expand All @@ -49,48 +56,7 @@
{
"cell_type": "markdown",
"source": [
"## Parameters"
],
"metadata": {
"collapsed": false
},
"id": "4dddd56883444fab"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"network = {\n",
" \"input_channels\": 4, \"outputs\": 4,\n",
" \"channels\": [32, 64, 64],\n",
" \"kernels\": [8, 4, 3],\n",
" \"padding\": [\"valid\", \"valid\", \"valid\"],\n",
" \"strides\": [4, 2, 1],\n",
" \"nodes\": [],\n",
"}\n",
"optimizer = {\n",
" \"optimizer\": torch.optim.Adam,\n",
" \"lr\": 1e-5,\n",
" \"hyperparameters\": {}\n",
"}\n",
"shape = {\n",
" \"original\": (1, 1, 210, 160),\n",
" \"width\": slice(7, -7),\n",
" \"height\": slice(31, -17),\n",
" \"max_pooling\": 2,\n",
"}\n",
"skip = 4"
],
"metadata": {
"collapsed": false
},
"id": "16867687f37ddbca",
"execution_count": null
},
{
"cell_type": "markdown",
"source": [
"## Setup"
"## Loading the agent and environment"
],
"metadata": {
"collapsed": false
Expand All @@ -101,13 +67,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"value_agent = VisionDeepQ(\n",
" network=network, optimizer=optimizer, shape=shape,\n",
" exploration_rate=0.002,\n",
")\n",
"\n",
"weights = torch.load(WEIGHTS, map_location=torch.device('cpu'))\n",
"value_agent.load_state_dict(weights)\n",
"agent = torch.load(MODEL, map_location=torch.device('cpu'))\n",
"\n",
"environment = gym.make('ALE/Breakout-v5', render_mode=\"rgb_array\",\n",
" obs_type=\"grayscale\", frameskip=1, repeat_action_probability=0.0)\n",
Expand All @@ -132,7 +92,7 @@
{
"cell_type": "markdown",
"source": [
"### Plotting the metrics from the csv-file created during training."
"### Training history (if specified)"
],
"metadata": {
"collapsed": false
Expand All @@ -143,7 +103,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"graph(METRICS, title=\"Training history\", window=20) if METRICS else None\n",
"graph(METRICS, title=\"Breakout training history\", window=20) if METRICS else None\n",
"plt.show() if METRICS else None"
],
"metadata": {
Expand All @@ -155,7 +115,7 @@
{
"cell_type": "markdown",
"source": [
"### Creating and saving a gif of the agent in action. The gif will be saved to the given path."
"### In action"
],
"metadata": {
"collapsed": false
Expand All @@ -166,7 +126,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"movie(environment, value_agent, './_output/breakout.avi', fps=60)"
"movie(environment, agent, './results/breakout.mp4', fps=20)"
],
"metadata": {
"collapsed": false
Expand Down
4 changes: 2 additions & 2 deletions cart-pole/DQN.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@
"from DQN import DeepQ\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import plot # noqa\n",
"from utilities.visualisation.gif import gif2 # noqa"
"from help.visualisation.plot import plot # noqa\n",
"from help.visualisation.gif import gif2 # noqa"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions cart-pole/REINFORCE.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@
"from REINFORCE import PolicyGradient\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import plot # noqa\n",
"from utilities.visualisation.gif import gif2 # noqa"
"from help.visualisation.plot import plot # noqa\n",
"from help.visualisation.gif import gif2 # noqa"
]
},
{
Expand Down
83 changes: 23 additions & 60 deletions enduro/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,30 @@
{
"cell_type": "markdown",
"source": [
"# Visualise the pre-trained agent in action\n",
"\n",
"Modify the path to the weights and run the notebook."
"# Example of visualising the agent's training history and performance"
],
"metadata": {
"collapsed": false
},
"id": "b3d8465ecb86eca7"
},
{
"cell_type": "markdown",
"source": [
"- `MODEL`: Path to the pre-trained model\n",
"- `METRICS`: Path to the training history, or None"
],
"metadata": {
"collapsed": false
},
"id": "fbd0af6b11428abe"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"WEIGHTS = './_output/weights-0.pth'\n",
"METRICS = None #'./_output/metrics.csv'"
"MODEL = './results/model.pth'\n",
"METRICS = None"
],
"metadata": {
"collapsed": false
Expand All @@ -34,11 +43,11 @@
"import gymnasium as gym\n",
"import matplotlib.pyplot as plt\n",
"\n",
"from DQN import VisionDeepQ\n",
"from train import SKIP\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import visualise_csv # noqa\n",
"from utilities.visualisation.gif import gif # noqa"
"from help.visualisation.plot import graph2 # noqa\n",
"from help.visualisation.gif import gif # noqa"
],
"metadata": {
"collapsed": false
Expand All @@ -49,47 +58,7 @@
{
"cell_type": "markdown",
"source": [
"## Parameters"
],
"metadata": {
"collapsed": false
},
"id": "4dddd56883444fab"
},
{
"cell_type": "code",
"outputs": [],
"source": [
"network = {\n",
" \"input_channels\": 2, \"outputs\": 9,\n",
" \"channels\": [32, 64, 64],\n",
" \"kernels\": [8, 4, 3],\n",
" \"padding\": [\"valid\", \"valid\", \"valid\"],\n",
" \"strides\": [4, 2, 1],\n",
" \"nodes\": [512],\n",
"}\n",
"optimizer = {\n",
" \"optimizer\": torch.optim.RMSprop,\n",
" \"lr\": 0.0001,\n",
" \"hyperparameters\": {}\n",
"}\n",
"shape = {\n",
" \"original\": (1, 1, 210, 160),\n",
" \"height\": slice(51, 155),\n",
" \"width\": slice(8, 160)\n",
"}\n",
"skip = 4"
],
"metadata": {
"collapsed": false
},
"id": "16867687f37ddbca",
"execution_count": null
},
{
"cell_type": "markdown",
"source": [
"## Setup"
"## Loading the agent and environment"
],
"metadata": {
"collapsed": false
Expand All @@ -100,13 +69,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"value_agent = VisionDeepQ(\n",
" network=network, optimizer=optimizer, shape=shape,\n",
" exploration_rate=0.01,\n",
")\n",
"\n",
"weights = torch.load(WEIGHTS, map_location=torch.device('cpu'))\n",
"value_agent.load_state_dict(weights)\n",
"agent = torch.load(MODEL, map_location=torch.device('cpu'))\n",
"\n",
"environment = gym.make('ALE/Enduro-v5', render_mode=\"rgb_array\",\n",
" obs_type=\"grayscale\", frameskip=1, repeat_action_probability=0.0)\n",
Expand All @@ -131,7 +94,7 @@
{
"cell_type": "markdown",
"source": [
"### Plotting the metrics from the csv-file created during training."
"### Training history (if specified)"
],
"metadata": {
"collapsed": false
Expand All @@ -142,7 +105,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"visualise_csv(METRICS, title=\"Training history\", window=20) if METRICS else None\n",
"graph2(METRICS, title=\"Enduro training history\", window=20) if METRICS else None\n",
"plt.show() if METRICS else None"
],
"metadata": {
Expand All @@ -154,7 +117,7 @@
{
"cell_type": "markdown",
"source": [
"### Creating and saving a gif of the agent in action. The gif will be saved to the given path."
"### In action"
],
"metadata": {
"collapsed": false
Expand All @@ -165,7 +128,7 @@
"cell_type": "code",
"outputs": [],
"source": [
"gif(environment, value_agent, './_output/enduro-0.gif', skip, 25)"
"gif(environment, agent, './results/enduro.gif', SKIP, 25)"
],
"metadata": {
"collapsed": false
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def gif2(environment, agent, path="./live-preview.gif", duration=50):
_ = imageio.mimsave(path, images, duration=duration)


def gif(environment, agent, path="./live-preview.gif", skip=4, duration=50):
def gif(environment, agent, path="./live-preview.gif", skip=1, duration=50):
"""
Create a GIF of the agent playing the environment.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch


def movie(environment, agent, path="./live-preview.mp4", skip=4, fps=50):
def movie(environment, agent, path="./live-preview.mp4", skip=1, fps=50):
"""Created by Mistral Large."""
states = agent.preprocess(environment.reset()[0])
if hasattr(agent, "shape") and "reshape" in agent.shape:
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions tetris/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@
"from DQN import DeepQ\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import graph # noqa\n",
"from utilities.visualisation.gif import gif # noqa"
"from help.visualisation.plot import graph # noqa\n",
"from help.visualisation.gif import gif # noqa"
],
"metadata": {
"collapsed": false,
Expand Down
4 changes: 2 additions & 2 deletions tetris/transfer-learning/DQN-ResNet18.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
"from DQN import TransferDeepQ\n",
"\n",
"sys.path.append(\"../\")\n",
"from utilities.visualisation.plot import plot # noqa\n",
"from utilities.visualisation.gif import gif # noqa"
"from help.visualisation.plot import plot # noqa\n",
"from help.visualisation.gif import gif # noqa"
]
},
{
Expand Down

0 comments on commit 08302f4

Please sign in to comment.