Skip to content

Commit

Permalink
Fix FlatObsWrapper obs dtype (#435)
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Jun 6, 2024
1 parent 6b99e06 commit 6762cb1
Showing 1 changed file with 9 additions and 12 deletions.
21 changes: 9 additions & 12 deletions minigrid/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,19 +569,17 @@ class FlatObsWrapper(ObservationWrapper):
(2835,)
"""

def __init__(self, env, maxStrLen=96):
def __init__(self, env, maxStrLen: int = 96):
super().__init__(env)

self.maxStrLen = maxStrLen
self.numCharCodes = 28

imgSpace = env.observation_space.spaces["image"]
imgSize = reduce(operator.mul, imgSpace.shape, 1)

img_size = np.prod(env.observation_space["image"].shape)
self.observation_space = spaces.Box(
low=0,
high=255,
shape=(imgSize + self.numCharCodes * self.maxStrLen,),
shape=(img_size + self.numCharCodes * self.maxStrLen,),
dtype="uint8",
)

Expand All @@ -598,12 +596,11 @@ def observation(self, obs):
), f"mission string too long ({len(mission)} chars)"
mission = mission.lower()

strArray = np.zeros(
shape=(self.maxStrLen, self.numCharCodes), dtype="float32"
)
str_array = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype="uint8")
# as `numCharCodes` < 255 then we can use `uint8`

for idx, ch in enumerate(mission):
if ch >= "a" and ch <= "z":
if "a" <= ch <= "z":
chNo = ord(ch) - ord("a")
elif ch == " ":
chNo = ord("z") - ord("a") + 1
Expand All @@ -613,11 +610,11 @@ def observation(self, obs):
raise ValueError(
f"Character {ch} is not available in mission string."
)
assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
strArray[idx, chNo] = 1
assert chNo < self.numCharCodes, f"{ch} : {chNo:d}"
str_array[idx, chNo] = 1

self.cachedStr = mission
self.cachedArray = strArray
self.cachedArray = str_array

obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))

Expand Down

0 comments on commit 6762cb1

Please sign in to comment.