Skip to content

Commit 6eab54b

Browse files
committed
fix merging
1 parent 4c44ec2 commit 6eab54b

File tree

2 files changed

+2
-103
lines changed

2 files changed

+2
-103
lines changed

minigrid/wrappers.py

+1-102
Original file line numberDiff line numberDiff line change
@@ -881,106 +881,6 @@ def step(self, action):
881881
reward += self.death_cost
882882

883883
return obs, reward, terminated, truncated, info
884-
885-
class StochasticActionWrapper(ActionWrapper):
886-
"""
887-
Add stochasticity to the actions
888-
889-
If a random action is provided, it is returned with probability `1 - prob`.
890-
Else, a random action is sampled from the action space.
891-
"""
892-
893-
def __init__(self, env=None, prob=0.9, random_action=None):
894-
super().__init__(env)
895-
self.prob = prob
896-
self.random_action = random_action
897-
898-
def action(self, action):
899-
""" """
900-
if np.random.uniform() < self.prob:
901-
return action
902-
else:
903-
if self.random_action is None:
904-
return self.np_random.integers(0, high=6)
905-
else:
906-
return self.random_action
907-
908-
909-
class NoDeath(Wrapper):
910-
"""
911-
Wrapper to prevent death in specific cells (e.g., lava cells).
912-
Instead of dying, the agent will receive a negative reward.
913-
914-
Example:
915-
>>> import gymnasium as gym
916-
>>> from minigrid.wrappers import NoDeath
917-
>>>
918-
>>> env = gym.make("MiniGrid-LavaCrossingS9N1-v0")
919-
>>> _, _ = env.reset(seed=2)
920-
>>> _, _, _, _, _ = env.step(1)
921-
>>> _, reward, term, *_ = env.step(2)
922-
>>> reward, term
923-
(0, True)
924-
>>>
925-
>>> env = NoDeath(env, no_death_types=("lava",), death_cost=-1.0)
926-
>>> _, _ = env.reset(seed=2)
927-
>>> _, _, _, _, _ = env.step(1)
928-
>>> _, reward, term, *_ = env.step(2)
929-
>>> reward, term
930-
(-1.0, False)
931-
>>>
932-
>>>
933-
>>> env = gym.make("MiniGrid-Dynamic-Obstacles-5x5-v0")
934-
>>> _, _ = env.reset(seed=2)
935-
>>> _, reward, term, *_ = env.step(2)
936-
>>> reward, term
937-
(-1, True)
938-
>>>
939-
>>> env = NoDeath(env, no_death_types=("ball",), death_cost=-1.0)
940-
>>> _, _ = env.reset(seed=2)
941-
>>> _, reward, term, *_ = env.step(2)
942-
>>> reward, term
943-
(-2.0, False)
944-
"""
945-
946-
def __init__(self, env, no_death_types: tuple[str, ...], death_cost: float = -1.0):
947-
"""A wrapper to prevent death in specific cells.
948-
949-
Args:
950-
env: The environment to apply the wrapper
951-
no_death_types: List of strings to identify death cells
952-
death_cost: The negative reward received in death cells
953-
954-
"""
955-
assert "goal" not in no_death_types, "goal cannot be a death cell"
956-
957-
super().__init__(env)
958-
self.death_cost = death_cost
959-
self.no_death_types = no_death_types
960-
961-
def step(self, action):
962-
# In Dynamic-Obstacles, obstacles move after the agent moves,
963-
# so we need to check for collision before self.env.step()
964-
front_cell = self.unwrapped.grid.get(*self.unwrapped.front_pos)
965-
going_to_death = (
966-
action == self.unwrapped.actions.forward
967-
and front_cell is not None
968-
and front_cell.type in self.no_death_types
969-
)
970-
971-
obs, reward, terminated, truncated, info = self.env.step(action)
972-
973-
# We also check if the agent stays in death cells (e.g., lava)
974-
# without moving
975-
current_cell = self.unwrapped.grid.get(*self.unwrapped.agent_pos)
976-
in_death = current_cell is not None and current_cell.type in self.no_death_types
977-
978-
if terminated and (going_to_death or in_death):
979-
terminated = False
980-
reward += self.death_cost
981-
982-
return obs, reward, terminated, truncated, info
983-
984884

985885
class MoveActionWrapper(Wrapper):
986886
"""
@@ -1010,8 +910,7 @@ def step(self, action):
1010910
else:
1011911
for _ in range(left_turns):
1012912
self.env.step(0)
1013-
913+
1014914
return self.env.step(2)
1015915
else:
1016916
return self.env.step(action - 1)
1017-

tests/test_wrappers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@
1616
FlatObsWrapper,
1717
FullyObsWrapper,
1818
ImgObsWrapper,
19-
NoDeath,
2019
MoveActionWrapper,
20+
NoDeath,
2121
OneHotPartialObsWrapper,
2222
PositionBonus,
2323
ReseedWrapper,

0 commit comments

Comments
 (0)