[Bug Report] FlatObsWrapper Leads to Crashes Due to Concatenation of uint8 and float32 Arrays in observation Method #434

Open
MarcSpeckmann opened this issue May 29, 2024 · 2 comments

Comments

@MarcSpeckmann
Contributor

Describe the bug
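The `observation` method of `minigrid.wrappers.FlatObsWrapper` concatenates a uint8 array (the flattened image) with a float32 array (the one-hot encoded mission string). NumPy promotes the result to float32, so the returned observation no longer matches the uint8 dtype declared by the wrapper's `observation_space`. RLlib crashes on this dtype mismatch.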
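The promotion can be shown on its own with a minimal NumPy sketch. It uses stand-in arrays: the 7x7x3 image and 4x28 mission shapes are arbitrary illustrative sizes, not the wrapper's actual dimensions.

```python
import numpy as np

# Stand-ins for the two arrays FlatObsWrapper concatenates (illustrative shapes).
image = np.zeros((7, 7, 3), dtype=np.uint8)           # grid encoding: uint8
mission_onehot = np.zeros((4, 28), dtype=np.float32)  # one-hot mission: float32

# np.concatenate promotes mixed inputs to a common dtype, here float32.
flat = np.concatenate((image.flatten(), mission_onehot.flatten()))

print(flat.dtype)                            # float32
print(np.result_type(np.uint8, np.float32))  # float32
```

Because float32 can represent every uint8 value exactly, NumPy picks float32 as the common type. The concatenation itself never fails. The result simply disagrees with a space that declares uint8.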

Code example

System Info
Describe the characteristics of your environment:

  • MiniGrid was installed via pip inside a conda env
  • CentOS Linux release 7.6.1810 (Core)
  • python=3.10.14
  • Package versions:
      torch==2.2.0
      ray[all]==2.23.0
      minigrid==2.3.1
      networkx==3.2.1
      gputil==1.4.0
      hydra_core==1.3.2
      hydra-callbacks==0.5.1

Additional context

```python
def observation(self, obs):
    image = obs["image"]
    mission = obs["mission"]

    # Cache the last-encoded mission string
    if mission != self.cachedStr:
        assert (
            len(mission) <= self.maxStrLen
        ), f"mission string too long ({len(mission)} chars)"
        mission = mission.lower()

        strArray = np.zeros(
            shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
        )

        for idx, ch in enumerate(mission):
            if ch >= "a" and ch <= "z":
                chNo = ord(ch) - ord("a")
            elif ch == " ":
                chNo = ord("z") - ord("a") + 1
            elif ch == ",":
                chNo = ord("z") - ord("a") + 2
            else:
                raise ValueError(
                    f"Character {ch} is not available in mission string."
                )
            assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
            strArray[idx, chNo] = 1

        self.cachedStr = mission
        self.cachedArray = strArray

    # uint8 image + float32 mission array: the result is promoted to float32
    obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

    return obs
```
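One possible fix is to cast the one-hot mission array to the image's dtype before concatenating. The flattened result then stays uint8 and matches the declared space. The sketch below pulls that step into a standalone helper. `concat_flat_obs` is a hypothetical name, and this is not the maintainers' patch. The alternative would be to declare the observation space as float32 instead, and the thread does not say which option is preferred.

```python
import numpy as np


def concat_flat_obs(image: np.ndarray, mission_onehot: np.ndarray) -> np.ndarray:
    """Hypothetical helper: flatten and concatenate without float32 promotion.

    The one-hot mission array only contains 0.0 and 1.0, so casting it to the
    image's integer dtype (uint8 in MiniGrid) loses no information.
    """
    return np.concatenate(
        (image.flatten(), mission_onehot.astype(image.dtype).flatten())
    )
```

Inside the wrapper, the concatenation line would then become `obs = concat_flat_obs(image, self.cachedArray)`.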

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)
@pseudo-rnd-thoughts
Member

Could you provide a script to recreate the error? If possible, do it without RLlib: just the environment and wrapper stack, showing the error on reset or step.

@MarcSpeckmann
Contributor Author

MarcSpeckmann commented May 30, 2024

Certainly! Below is a simplified example script.

```python
import gymnasium
import minigrid
from minigrid.wrappers import FlatObsWrapper

env = FlatObsWrapper(gymnasium.make("MiniGrid-Empty-5x5-v0"))
observation, info = env.reset(seed=42)
print(env.observation_space.dtype)
print(observation.dtype)
assert env.observation_space.dtype == observation.dtype
```
```
uint8
float32
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[8], line 9
      7 print(env.observation_space.dtype)
      8 print(observation.dtype)
----> 9 assert env.observation_space.dtype == observation.dtype

AssertionError:
```
