Skip to content

Commit

Permalink
Save latest auto rendered frame as cache (Farama-Foundation#1010)
Browse files Browse the repository at this point in the history
  • Loading branch information
RogerJL committed Jun 10, 2024
1 parent 18daeec commit 583b3ba
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 12 deletions.
9 changes: 7 additions & 2 deletions gymnasium/wrappers/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def __init__(self, env: gym.Env[ObsType, ActType], auto_rendering: bool = True):
self.window = None # Has to be initialized before asserts, as self.window is used in auto close
self.clock = None
self.auto_rendering = auto_rendering
self._latest_frame = None

assert (
self.env.render_mode in self.ACCEPTED_RENDER_MODES
Expand All @@ -498,7 +499,7 @@ def step(self, action: ActType) -> tuple[ObsType, SupportsFloat, bool, bool, dic
"""Perform a step in the base environment and render a frame to the screen."""
result = super().step(action)
if self.auto_rendering:
self._render_frame()
self._latest_frame = self._render_frame()
return result

def reset(
Expand All @@ -507,11 +508,15 @@ def reset(
"""Reset the base environment and render a frame to the screen."""
result = super().reset(seed=seed, options=options)
if self.auto_rendering:
self._render_frame()
self._latest_frame = self._render_frame()
return result

def render(self):
"""This method doesn't do much, actual rendering is usually performed in :meth:`step` and :meth:`reset`."""
if self.auto_rendering and self._latest_frame is not None:
frame = self._latest_frame
self._latest_frame = None
return frame
return self._render_frame()

def _render_frame(self):
Expand Down
8 changes: 6 additions & 2 deletions gymnasium/wrappers/vector/rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def step(
"""Perform a step in the base environment and render a frame to the screen."""
result = super().step(actions)
if self.auto_rendering:
self._render_frame()
self._latest_frame = self._render_frame()
return result

def reset(
Expand All @@ -77,11 +77,15 @@ def reset(
"""Reset the base environment and render a frame to the screen."""
result = super().reset(seed=seed, options=options)
if self.auto_rendering:
self._render_frame()
self._latest_frame = self._render_frame()
return result

def render(self):
"""This method doesn't do much, actual rendering is usually performed in :meth:`step` and :meth:`reset`."""
if self.auto_rendering and self._latest_frame is not None:
frame = self._latest_frame
self._latest_frame = None
return frame
return self._render_frame()

def _render_frame(self):
Expand Down
9 changes: 6 additions & 3 deletions tests/wrappers/test_human_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,23 @@ def test_human_rendering():
@pytest.mark.parametrize("env_id", ["CartPole-v1"])
@pytest.mark.parametrize("num_envs", [1, 3, 9])
@pytest.mark.parametrize("screen_size", [None])
def test_human_rendering_manual(env_id, num_envs, screen_size):
@pytest.mark.parametrize("auto_rendering", [False, True])
def test_human_rendering_manual(env_id, num_envs, screen_size, auto_rendering):
env = HumanRendering(
gym.make("CartPole-v1", render_mode="rgb_array", disable_env_checker=True),
auto_rendering=False,
auto_rendering=auto_rendering,
)
assert env.render_mode == "human"
assert env.auto_rendering == auto_rendering

env.reset()

for _ in range(75):
_, _, terminated, truncated, _ = env.step(env.action_space.sample())
if terminated or truncated:
env.reset()
rendering = env.render()
# output should match mode
rendering = env.render()
assert isinstance(rendering, np.ndarray)

env.close()
18 changes: 13 additions & 5 deletions tests/wrappers/vector/test_human_rendering.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,21 @@ def test_num_envs_screen_size(env_id, num_envs, screen_size):


def test_render_modes():
num_envs = 3
envs = HumanRendering(
gym.make_vec("CartPole-v1", num_envs=3, render_mode="rgb_array_list")
gym.make_vec("CartPole-v1", num_envs=num_envs, render_mode="rgb_array_list")
)
assert envs.render_mode == "human"

envs.reset()
for _ in range(25):
envs.step(envs.action_space.sample())
# output should match mode, list of environment rgb_arrays
rendering = envs.render()
assert isinstance(rendering, list)
assert len(rendering) == num_envs
assert isinstance(rendering[0], np.ndarray)

envs.close()

# HumanRenderer on human renderer should not work
Expand All @@ -48,19 +55,20 @@ def test_render_modes():
@pytest.mark.parametrize("env_id", ["CartPole-v1"])
@pytest.mark.parametrize("num_envs", [1, 3, 9])
@pytest.mark.parametrize("screen_size", [None])
def test_human_rendering_manual(env_id, num_envs, screen_size):
@pytest.mark.parametrize("auto_rendering", [False, True])
def test_human_rendering_manual(env_id, num_envs, screen_size, auto_rendering):
envs = gym.make_vec(env_id, num_envs=num_envs, render_mode="rgb_array")
envs = HumanRendering(envs, screen_size=screen_size, auto_rendering=False)
envs = HumanRendering(envs, screen_size=screen_size, auto_rendering=auto_rendering)

assert envs.render_mode == "human"
assert not envs.auto_rendering
assert envs.auto_rendering == auto_rendering

envs.reset()

# Test Manual render() call
envs.step(envs.action_space.sample())
rendering = envs.render()
# output should match mode, list of environment rgb_arrays
rendering = envs.render()
assert isinstance(rendering, list)
assert len(rendering) == num_envs
assert isinstance(rendering[0], np.ndarray)
Expand Down

0 comments on commit 583b3ba

Please sign in to comment.