
[Bug Report] FlatObsWrapper Leads to Crashes Due to Concatenation of uint8 and float32 Arrays in observation Method #434

@MarcSpeckmann

Description



Describe the bug
The observation method of minigrid.wrappers.FlatObsWrapper concatenates a uint8 array (the flattened image) with a float32 array (the cached one-hot mission encoding). NumPy promotes the concatenated result to float32, and this mix of dtypes makes RLlib crash.
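The promotion can be reproduced with NumPy alone. This is a minimal sketch with zero-filled stand-in arrays. The 7x7x3 image shape and the 96x28 mission encoding are assumptions (chosen to mirror the maxStrLen and numCharCodes attributes used by the wrapper) and are not taken from this report.

```python
import numpy as np

# Stand-in arrays; shapes are assumptions, not values from the report
image = np.zeros((7, 7, 3), dtype=np.uint8)            # like obs["image"]
mission_one_hot = np.zeros((96, 28), dtype="float32")  # like self.cachedArray

# Same call as in FlatObsWrapper.observation
flat = np.concatenate((image.flatten(), mission_one_hot.flatten()))

# Mixed uint8/float32 inputs are promoted, so the result is float32, not uint8
print(image.dtype, mission_one_hot.dtype, "->", flat.dtype)  # uint8 float32 -> float32
print(np.result_type(np.uint8, np.float32))                  # float32
```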

Code example

System Info
Describe the characteristics of your environment:

  • MiniGrid was installed via pip inside a conda env
  • CentOS Linux release 7.6.1810 (Core)
  • python=3.10.14
  • torch==2.2.0
  • ray[all]==2.23.0
  • minigrid==2.3.1
  • networkx==3.2.1
  • gputil==1.4.0
  • hydra_core==1.3.2
  • hydra-callbacks==0.5.1

Additional context

def observation(self, obs):
    image = obs["image"]
    mission = obs["mission"]

    # Cache the last-encoded mission string
    if mission != self.cachedStr:
        assert (
            len(mission) <= self.maxStrLen
        ), f"mission string too long ({len(mission)} chars)"
        mission = mission.lower()

        strArray = np.zeros(
            shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
        )

        for idx, ch in enumerate(mission):
            if ch >= "a" and ch <= "z":
                chNo = ord(ch) - ord("a")
            elif ch == " ":
                chNo = ord("z") - ord("a") + 1
            elif ch == ",":
                chNo = ord("z") - ord("a") + 2
            else:
                raise ValueError(
                    f"Character {ch} is not available in mission string."
                )
            assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
            strArray[idx, chNo] = 1

        self.cachedStr = mission
        self.cachedArray = strArray

    # image is uint8 and self.cachedArray is float32, so the result is float32
    obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

    return obs
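One possible workaround, sketched as standalone functions so it runs without MiniGrid installed: allocate the one-hot mission encoding as uint8 from the start (its entries are only 0 or 1) and cast the image to uint8, so the flattened observation keeps a single dtype. The names encode_mission and flatten_obs_uint8 are hypothetical, the 96/28 sizes are assumed defaults, and the assumption that every image entry fits in 0..255 should be checked against the environment. If the wrapper's observation space declares some other dtype, the same idea applies with that dtype instead.

```python
import numpy as np

MAX_STR_LEN = 96      # assumed, mirrors self.maxStrLen
NUM_CHAR_CODES = 28   # assumed, mirrors self.numCharCodes


def encode_mission(mission: str,
                   max_str_len: int = MAX_STR_LEN,
                   num_char_codes: int = NUM_CHAR_CODES) -> np.ndarray:
    """Hypothetical helper: one-hot encode the mission with the same
    character mapping as the quoted method, but as uint8 instead of float32."""
    if len(mission) > max_str_len:
        raise ValueError(f"mission string too long ({len(mission)} chars)")
    str_array = np.zeros((max_str_len, num_char_codes), dtype=np.uint8)
    for idx, ch in enumerate(mission.lower()):
        if "a" <= ch <= "z":
            ch_no = ord(ch) - ord("a")
        elif ch == " ":
            ch_no = ord("z") - ord("a") + 1
        elif ch == ",":
            ch_no = ord("z") - ord("a") + 2
        else:
            raise ValueError(f"Character {ch} is not available in mission string.")
        if ch_no >= num_char_codes:
            raise ValueError(f"{ch} : {ch_no}")
        str_array[idx, ch_no] = 1
    return str_array


def flatten_obs_uint8(image: np.ndarray, mission: str) -> np.ndarray:
    """Hypothetical helper: flatten image and mission into one uint8 vector.
    Assumes every image entry already lies in 0..255 (no clipping is done)."""
    image_u8 = np.asarray(image, dtype=np.uint8)
    return np.concatenate((image_u8.flatten(), encode_mission(mission).flatten()))


obs = flatten_obs_uint8(np.zeros((7, 7, 3), dtype=np.uint8), "go to the red ball")
print(obs.dtype, obs.shape)  # uint8 (2835,)
```

Because both inputs to np.concatenate are already uint8, no promotion happens and the output dtype matches a uint8 observation space.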

Checklist

  • [x] I have checked that there is no similar issue in the repo (required)

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests