Skip to content

Commit

Permalink
Merge pull request #158 from edbeeching/fix-image-decoding-in-godot_e…
Browse files Browse the repository at this point in the history
…nv.py

fix dtype of image data
  • Loading branch information
Ivan-267 authored Jan 17, 2024
2 parents 8eeef0a + a84160c commit 69495f4
Showing 1 changed file with 38 additions and 35 deletions.
73 changes: 38 additions & 35 deletions godot_rl/core/godot_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,32 @@
import socket
import subprocess
import time
from collections import OrderedDict
from sys import platform
from typing import Optional

import numpy as np
from gymnasium import spaces
from typing import Optional
from godot_rl.core.utils import ActionSpaceProcessor, convert_macos_path
from gymnasium import spaces

from collections import OrderedDict

class GodotEnv:
MAJOR_VERSION = "0" # Versioning for the environment
MAJOR_VERSION = "0" # Versioning for the environment
MINOR_VERSION = "4"
DEFAULT_PORT = 11008 # Default port for communication with Godot Game
DEFAULT_TIMEOUT = 60 # Default socket timeout TODO
DEFAULT_PORT = 11008 # Default port for communication with Godot Game
DEFAULT_TIMEOUT = 60 # Default socket timeout TODO

def __init__(
self,
env_path: str=None,
port: int=DEFAULT_PORT,
show_window: bool=False,
seed:int=0,
framerate:Optional[int]=None,
action_repeat:Optional[int]=None,
speedup:Optional[int]=None,
convert_action_space:bool=False,
self,
env_path: str = None,
port: int = DEFAULT_PORT,
show_window: bool = False,
seed: int = 0,
framerate: Optional[int] = None,
action_repeat: Optional[int] = None,
speedup: Optional[int] = None,
convert_action_space: bool = False,
):
"""
Initialize a new instance of GodotEnv
Expand Down Expand Up @@ -97,18 +98,18 @@ def check_platform(self, filename: str):
if platform == "linux" or platform == "linux2":
# Linux
assert (
pathlib.Path(filename).suffix == ".x86_64"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix }. Please provide a .x86_64 file"
pathlib.Path(filename).suffix == ".x86_64"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .x86_64 file"
elif platform == "darwin":
# OSX
assert (
pathlib.Path(filename).suffix == ".app"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix }. Please provide a .app file"
pathlib.Path(filename).suffix == ".app"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .app file"
elif platform == "win32":
# Windows...
assert (
pathlib.Path(filename).suffix == ".exe"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix }. Please provide a .exe file"
pathlib.Path(filename).suffix == ".exe"
), f"Incorrect file suffix for filename {filename} suffix {pathlib.Path(filename).suffix}. Please provide a .exe file"
else:
assert 0, f"unknown filetype {pathlib.Path(filename).suffix}"

Expand Down Expand Up @@ -158,7 +159,6 @@ def step(self, action, order_ij=False):
self.step_send(action, order_ij=order_ij)
return self.step_recv()


def step_send(self, action, order_ij=False):
"""
Send the action to the Godot environment.
Expand Down Expand Up @@ -192,8 +192,6 @@ def step_recv(self):
[{}] * len(response["done"]),
)



def _process_obs(self, response_obs: dict):
"""
Process observation data.
Expand Down Expand Up @@ -336,12 +334,20 @@ def _get_env_info(self):
print("observation space", json_dict["observation_space"])
for k, v in json_dict["observation_space"].items():
if v["space"] == "box":
observation_spaces[k] = spaces.Box(
low=-1.0,
high=1.0,
shape=v["size"],
dtype=np.float32,
)
if "2d" in k:
observation_spaces[k] = spaces.Box(
low=0,
high=255,
shape=v["size"],
dtype=np.uint8,
)
else:
observation_spaces[k] = spaces.Box(
low=-1.0,
high=1.0,
shape=v["size"],
dtype=np.float32,
)
elif v["space"] == "discrete":
observation_spaces[k] = spaces.Discrete(v["size"])
else:
Expand All @@ -351,17 +357,14 @@ def _get_env_info(self):

self.num_envs = json_dict["n_agents"]



@staticmethod
def _decode_2d_obs_from_string(
hex_string,
shape,
hex_string,
shape,
):
return (
np.frombuffer(bytes.fromhex(hex_string), dtype=np.float16)
np.frombuffer(bytes.fromhex(hex_string), dtype=np.uint8)
.reshape(shape)
.astype(np.float32)[:, :, :] # TODO remove the alpha channel
)

def _send_as_json(self, dictionary):
Expand Down

0 comments on commit 69495f4

Please sign in to comment.