Skip to content

Commit

Permalink
GitHub Actions fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
hallvardnmbu committed Mar 11, 2024
1 parent b7872b7 commit 0c9744e
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 16 deletions.
6 changes: 4 additions & 2 deletions reinforcement-learning/breakout/DQN.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,7 +275,7 @@ def preprocess(self, state):

return state

def observe(self, environment, states, *args): # noqa
def observe(self, environment, states, skip=None):
"""
Observe the environment for n frames.
Expand All @@ -285,7 +285,7 @@ def observe(self, environment, states, *args): # noqa
The environment to observe.
states : torch.Tensor
The states of the environment from the previous step.
args
skip : int, optional
To be compatible with the other DQN agents. Added here instead of using ABC.
Returns
Expand All @@ -299,6 +299,8 @@ def observe(self, environment, states, *args): # noqa
done : bool
Whether the game is terminated.
"""
print("Warning: `skip` is not used in `VisionDeepQ.observe`.") if skip is not None else None

action = self.action(states)

done = False
Expand Down
24 changes: 10 additions & 14 deletions reinforcement-learning/utilities/visualisation/movie.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,25 @@
"""Create a movie of an agent interacting with an environment."""

import cv2
import torch


def create_movie(environment, agent, path, fps=60):
def create_movie(environment, agent, path="./live-preview.gif", skip=4, fps=50):
"""Created by Mistral Large."""
initial = agent.preprocess(environment.reset()[0])
try:
states = torch.cat([initial] * agent.shape["reshape"][1], dim=1)
except AttributeError:
states = initial

try:
done = False
done = False

# Get the dimensions of the first image
height, width, channels = environment.render().shape
height, width, _ = environment.render().shape
fourcc = cv2.VideoWriter_fourcc(*"MJPG") # noqa
movie = cv2.VideoWriter(path, fourcc, fps, (width, height))

# Create the VideoWriter object
fourcc = cv2.VideoWriter_fourcc(*"MJPG") # You can change the codec if needed
video_writer = cv2.VideoWriter(path, fourcc, fps, (width, height))
while not done:
_, states, _, done = agent.observe(environment, states)
video_writer.write(environment.render())
except Exception as e:
print(f"Error during image generation or writing: {e}")
return
while not done:
_, states, _, done = agent.observe(environment, states, skip)
movie.write(environment.render())

cv2.destroyAllWindows()
4 changes: 4 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@ torch
numpy
pandas
scipy

# For visualisation
# *-----------------------------------------*
matplotlib
imageio
opencv-python

# Library for utilising Apple M chips
# *-----------------------------------------*
Expand Down

0 comments on commit 0c9744e

Please sign in to comment.