Commit b7872b7

Updated Enduro and Breakout algorithm. Waiting for Orion weights.
hallvardnmbu committed Mar 11, 2024
1 parent 3779353 commit b7872b7
Showing 4 changed files with 44 additions and 37 deletions.
4 changes: 3 additions & 1 deletion reinforcement-learning/breakout/DQN.py
@@ -275,7 +275,7 @@ def preprocess(self, state):

         return state

-    def observe(self, environment, states):
+    def observe(self, environment, states, *args):  # noqa
         """
         Observe the environment for n frames.
@@ -285,6 +285,8 @@ def observe(self, environment, states):
             The environment to observe.
         states : torch.Tensor
             The states of the environment from the previous step.
+        args
+            To be compatible with the other DQN agents. Added here instead of using ABC.

         Returns
         -------
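Note: the `*args` change lets one call site pass a frame-skip count to every agent, whether or not the agent uses it, without introducing an abstract base class. A minimal sketch of the pattern (the two classes below are illustrative stand-ins, not the repository's implementations):

    class BreakoutAgent:
        def observe(self, environment, states, *args):
            # Extra positional arguments (e.g. a skip count) are accepted and ignored.
            return "breakout"


    class EnduroAgent:
        def observe(self, environment, states, skip=1):
            return f"enduro, skip={skip}"


    for agent in (BreakoutAgent(), EnduroAgent()):
        print(agent.observe(None, None, 4))  # the same call works for both agents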
23 changes: 14 additions & 9 deletions reinforcement-learning/enduro/DQN.py
@@ -277,24 +277,29 @@ def observe(self, environment, states, skip=1):
             The action taken.
         states : torch.Tensor
             The states of the environment.
-        rewards : torch.Tensor
+        rewards : float
             The rewards of the environment.
         done : bool
             Whether the game is terminated.
         """
         action = self.action(states)

         done = False
-        rewards = torch.tensor([0.0])
-        states = torch.zeros((1, skip, *self.shape["reshape"][2:4]))
+        rewards = 0.0
+        states = torch.zeros(self.shape["reshape"])

-        for i in range(0, skip):
-            new_state, reward, terminated, truncated, _ = environment.step(action.item())
-            done = (terminated or truncated) if not done else done
-            rewards += reward
-            states[0, i] = self.preprocess(new_state)
-        states = torch.max(states, dim=1, keepdim=True).values
+        for i in range(0, self.shape["reshape"][1]):
+            new_states = torch.zeros((1, skip, *self.shape["reshape"][2:4]))
+
+            for j in range(skip):
+                new_state, reward, terminated, truncated, _ = environment.step(action.item())
+                done = (terminated or truncated) if not done else done
+                rewards += reward
+
+                new_states[0, j] = self.preprocess(new_state)
+
+            states[0, i] = torch.max(new_states, dim=1, keepdim=True).values

         return action, states, rewards, done
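Note: the rewritten loop is two-level. The outer loop fills each of the `self.shape["reshape"][1]` stacked frames; the inner loop steps the emulator `skip` times and keeps the pixel-wise maximum, the usual Atari trick for suppressing sprite flicker. A self-contained sketch of the idea, with a dummy environment and a stand-in `preprocess` (grayscale plus resize); names and shapes are illustrative assumptions, not the repository's code:

    import torch


    class DummyEnv:
        # Stands in for the Atari environment; returns fake RGB frames.
        def step(self, action):
            return torch.rand(210, 160, 3), 1.0, False, False, {}


    def preprocess(frame, shape=(84, 84)):
        gray = frame.float().mean(dim=-1)  # collapse RGB to grayscale
        return torch.nn.functional.interpolate(gray[None, None], size=shape)[0, 0]


    def observe(environment, action, stack=2, skip=2, shape=(84, 84)):
        done, rewards = False, 0.0
        states = torch.zeros((1, stack, *shape))
        for i in range(stack):
            frames = torch.zeros((1, skip, *shape))
            for j in range(skip):
                frame, reward, terminated, truncated, _ = environment.step(action)
                done = done or terminated or truncated
                rewards += reward
                frames[0, j] = preprocess(frame, shape)
            # Pixel-wise max over the skipped frames suppresses flicker.
            states[0, i] = frames.max(dim=1).values[0]
        return states, rewards, done


    states, rewards, done = observe(DummyEnv(), action=0)
    print(states.shape, rewards, done)  # torch.Size([1, 2, 84, 84]) 4.0 False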
52 changes: 26 additions & 26 deletions reinforcement-learning/enduro/train.py
@@ -9,7 +9,6 @@
 import glob
 import copy
 import time
-import random
 import logging

 import torch
@@ -47,14 +46,13 @@
 # EXPLORATION_STEPS : number of games to decay exploration rate from `RATE` to `MIN`
 # MINIBATCH : size of the minibatch
 # TRAIN_EVERY : train the network every n games
-# START_TRAINING_AT : start training after n games
 # REMEMBER : only remember games with rewards, and this fraction of the games without
 # MEMORY : size of the agents internal memory
 # RESET_Q_EVERY : update target-network every n games

-GAMES = 100000
-SKIP = 6
-CHECKPOINT = 5000
+GAMES = 500
+SKIP = 2
+CHECKPOINT = 100

 SHAPE = {
     "original": (1, 1, 210, 160),
@@ -64,25 +62,24 @@

 DISCOUNT = 0.95
 GAMMA = 0.99
-GRADIENTS = (-1, 1)
+GRADIENTS = (-10, 10)

-PUNISHMENT = 0
+PUNISHMENT = -10
 INCENTIVE = 1

-MINIBATCH = 1
-TRAIN_EVERY = 2
-START_TRAINING_AT = 1000
+MINIBATCH = 5
+TRAIN_EVERY = 1

-EXPLORATION_RATE = 0.9
+EXPLORATION_RATE = 0.5
 EXPLORATION_MIN = 0.01
-EXPLORATION_STEPS = 30000 // TRAIN_EVERY
+EXPLORATION_STEPS = 200 // TRAIN_EVERY

-REMEMBER = 0.005
-MEMORY = 500
-RESET_Q_EVERY = TRAIN_EVERY * 1000
+REMEMBER_FIRST = True
+MEMORY = 5
+RESET_Q_EVERY = TRAIN_EVERY * 5

 NETWORK = {
-    "input_channels": 1, "outputs": 8,
+    "input_channels": 2, "outputs": 9,
     "channels": [32, 64, 64],
     "kernels": [8, 4, 3],
     "padding": ["valid", "valid", "valid"],
@@ -144,33 +141,36 @@

 TRAINING = False
 _STEPS = _LOSS = _REWARD = 0
+STEP = SKIP * value_agent.shape["reshape"][1]
 for game in range(1, GAMES + 1):

     initial = value_agent.preprocess(environment.reset()[0])
     states = torch.cat([initial] * value_agent.shape["reshape"][1], dim=1)

     DONE = False
-    STEPS = REWARDS = 0
-    TRAINING = True if (not TRAINING and game >= START_TRAINING_AT) else TRAINING
+    STEPS = 0
+    REWARDS = []
+    TRAINING = True if (not TRAINING and len(value_agent.memory["memory"]) > 0) else TRAINING
     while not DONE:
         action, new_states, rewards, DONE = value_agent.observe(environment, states, SKIP)
-        value_agent.remember(states, action, rewards)
+        value_agent.remember(states, action, torch.tensor(rewards))

         states = new_states
-        REWARDS += rewards.item()
-        STEPS += 1
+        REWARDS.append(rewards)
+        STEPS += STEP

-    if random.random() < REMEMBER or REWARDS > 0:
+    if len(REWARDS) > 0 or REMEMBER_FIRST:
+        REMEMBER_FIRST = False
         value_agent.memorize(states, STEPS)
-        logger.debug(" %s --> (%s) %s", game, int(STEPS), int(REWARDS))
+        logger.debug(" %s --> (%s) %s", game, int(STEPS), len(REWARDS))
         value_agent.memory["game"].clear()

     LOSS = None
-    if game % TRAIN_EVERY == 0 and len(value_agent.memory["memory"]) > 0 and TRAINING:
+    if game % TRAIN_EVERY == 0 and TRAINING:
         LOSS = value_agent.learn(network=_value_agent, clamp=GRADIENTS)
         EXPLORATION_RATE = value_agent.parameter["rate"]
         _LOSS += LOSS
-        _REWARD += REWARDS
+        _REWARD += sum(REWARDS)
         _STEPS += STEPS

     if game % RESET_Q_EVERY == 0 and TRAINING:
@@ -184,7 +184,7 @@

     with open(METRICS, "a", newline="", encoding="UTF-8") as file:
         metric = csv.writer(file)
-        metric.writerow([game, STEPS, LOSS, EXPLORATION_RATE, REWARDS])
+        metric.writerow([game, STEPS, LOSS, EXPLORATION_RATE, len(REWARDS)])

     if game % (CHECKPOINT // 2) == 0 or game == GAMES:
         logger.info("Game %s (progress %s %%, random %s %%)",
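Note: `RESET_Q_EVERY` drops from every 1000 trained games to every 5, so the frozen copy (`_value_agent`) that supplies bootstrap targets is refreshed far more often. The hard-update pattern it controls, sketched with a toy network (the `Linear` layer is a stand-in, not the repository's model):

    import copy

    import torch

    online = torch.nn.Linear(4, 2)   # stands in for value_agent's network
    target = copy.deepcopy(online)   # stands in for _value_agent

    RESET_Q_EVERY = 5
    for game in range(1, 11):
        # ... train `online`, computing TD targets from the frozen `target` ...
        if game % RESET_Q_EVERY == 0:
            target.load_state_dict(online.state_dict())  # hard update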
2 changes: 1 addition & 1 deletion reinforcement-learning/utilities/visualisation/gif.py
@@ -55,7 +55,7 @@ def gif_stacked(environment, agent, path="./live-preview.gif", skip=4, duration=
     images = []
     done = False
     while not done:
-        _, states, _, done = agent.observe(environment, states)
+        _, states, _, done = agent.observe(environment, states, skip)

         images.append(environment.render())
     _ = imageio.mimsave(path, images, duration=duration)
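Note: forwarding `skip` makes the helper work both for agents whose `observe` requires a skip count (Enduro) and for those that ignore extra positional arguments (the Breakout `*args` change above). The final `imageio.mimsave` call writes the collected frames as an animated gif; a minimal self-contained illustration with synthetic frames in place of `environment.render()`:

    import imageio
    import numpy as np

    # Three solid frames of increasing brightness stand in for rendered screens.
    images = [np.full((64, 64, 3), value, dtype=np.uint8) for value in (0, 128, 255)]
    imageio.mimsave("./live-preview.gif", images, duration=0.5)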
