Description
Describe the bug
The `observation` method of `minigrid.wrappers.FlatObsWrapper` concatenates a uint8 array (the flattened image) with a float32 array (the one-hot encoded mission string). This mix of dtypes causes RLlib to crash.
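A minimal NumPy-only sketch of what happens at the end of `observation`: `np.concatenate` does not reject mixed inputs. It promotes them to a common dtype, so the flattened observation comes out as float32 even though the image part is uint8. The array shapes below are illustrative assumptions, not the wrapper's real sizes.

```python
import numpy as np

# Stand-ins for the two arrays that FlatObsWrapper.observation concatenates.
# Shapes are illustrative assumptions only.
image = np.zeros((7, 7, 3), dtype=np.uint8)            # image observation (uint8)
mission_onehot = np.zeros((4, 28), dtype=np.float32)   # one-hot mission (float32)

flat = np.concatenate((image.flatten(), mission_onehot.flatten()))

# NumPy promotes the mixed inputs instead of raising an error,
# so the result is float32 although the image part was uint8.
print(image.dtype, mission_onehot.dtype, flat.dtype)  # uint8 float32 float32
```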
Code example
System Info
Describe the characteristics of your environment:
- MiniGrid was installed via pip inside a conda env
- CentOS Linux release 7.6.1810 (Core)
- python=3.10.14
- torch==2.2.0
- ray[all]==2.23.0
- minigrid==2.3.1
- networkx==3.2.1
- gputil==1.4.0
- hydra_core==1.3.2
- hydra-callbacks==0.5.1
Additional context
Lines 590 to 624 in 37f28b2
```python
def observation(self, obs):
    image = obs["image"]
    mission = obs["mission"]
    # Cache the last-encoded mission string
    if mission != self.cachedStr:
        assert (
            len(mission) <= self.maxStrLen
        ), f"mission string too long ({len(mission)} chars)"
        mission = mission.lower()
        strArray = np.zeros(
            shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
        )
        for idx, ch in enumerate(mission):
            if ch >= "a" and ch <= "z":
                chNo = ord(ch) - ord("a")
            elif ch == " ":
                chNo = ord("z") - ord("a") + 1
            elif ch == ",":
                chNo = ord("z") - ord("a") + 2
            else:
                raise ValueError(
                    f"Character {ch} is not available in mission string."
                )
            assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
            strArray[idx, chNo] = 1
        self.cachedStr = mission
        self.cachedArray = strArray
    obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
    return obs
```
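One possible workaround, sketched under assumptions: build the mission encoding and the image in a single, caller-chosen dtype before concatenating. `flatten_obs` is a hypothetical standalone helper, not a MiniGrid API. It reuses the character mapping from the excerpt above, drops the caching for brevity, and its defaults (`max_str_len=96`, `num_char_codes=28`) are assumed to match the wrapper's.

```python
import numpy as np


# Hypothetical helper (not part of MiniGrid). Encodes the mission with the same
# a-z / space / comma mapping as FlatObsWrapper.observation, but uses one dtype
# for both halves so the flattened observation never mixes dtypes.
# Default sizes are assumptions, not confirmed MiniGrid values.
def flatten_obs(image, mission, max_str_len=96, num_char_codes=28, dtype=np.uint8):
    mission = mission.lower()
    assert len(mission) <= max_str_len, f"mission string too long ({len(mission)} chars)"
    str_array = np.zeros((max_str_len, num_char_codes), dtype=dtype)
    for idx, ch in enumerate(mission):
        if "a" <= ch <= "z":
            ch_no = ord(ch) - ord("a")
        elif ch == " ":
            ch_no = ord("z") - ord("a") + 1
        elif ch == ",":
            ch_no = ord("z") - ord("a") + 2
        else:
            raise ValueError(f"Character {ch} is not available in mission string.")
        assert ch_no < num_char_codes, f"{ch} : {ch_no}"
        str_array[idx, ch_no] = 1
    # Cast the image explicitly so both parts share `dtype` before concatenation.
    return np.concatenate((image.astype(dtype).ravel(), str_array.ravel()))


img = np.ones((7, 7, 3), dtype=np.uint8)
print(flatten_obs(img, "get the key").dtype)                    # uint8
print(flatten_obs(img, "get the key", dtype=np.float32).dtype)  # float32
```

Keeping uint8 loses nothing here, since the mission entries are only 0 or 1. Whichever dtype is chosen, the wrapper's declared observation space should presumably use the same one.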
Checklist
- [x] I have checked that there is no similar issue in the repo (required)