Skip to content

Commit eeaf94e

Browse files
committed
Remove code references to local_reward & make some refactors for improved clarity.
1 parent 369848b commit eeaf94e

File tree

1 file changed

+35
-50
lines changed

1 file changed

+35
-50
lines changed

pettingzoo/butterfly/pistonball/pistonball.py

+35-50
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,6 @@
9595
from pettingzoo.utils import AgentSelector, wrappers
9696
from pettingzoo.utils.conversions import parallel_wrapper_fn
9797

98-
_image_library = {}
99-
10098
FPS = 20
10199

102100
__all__ = ["ManualPolicy", "env", "parallel_env", "raw_env"]
@@ -239,8 +237,7 @@ def __init__(
239237
)
240238
self.recentPistons = set() # Set of pistons that have touched the ball recently
241239
self.time_penalty = time_penalty
242-
# TODO: this was a bad idea and the logic this uses should be removed at some point
243-
self.local_ratio = 0
240+
244241
self.ball_mass = ball_mass
245242
self.ball_friction = ball_friction
246243
self.ball_elasticity = ball_elasticity
@@ -466,8 +463,8 @@ def reset(self, seed=None, options=None):
466463
-6 * math.pi, 6 * math.pi
467464
)
468465

469-
self.lastX = int(self.ball.position[0] - self.ball_radius)
470-
self.distance = self.lastX - self.wall_width
466+
self.ball_prev_pos = self._get_ball_position()
467+
self.distance_to_wall_at_game_start = self.ball_prev_pos - self.wall_width
471468

472469
self.draw_background()
473470
self.draw()
@@ -566,30 +563,6 @@ def draw(self):
566563
)
567564
self.draw_pistons()
568565

569-
def get_nearby_pistons(self):
570-
# first piston = leftmost
571-
nearby_pistons = []
572-
ball_pos = int(self.ball.position[0] - self.ball_radius)
573-
closest = abs(self.pistonList[0].position.x - ball_pos)
574-
closest_piston_index = 0
575-
for i in range(self.n_pistons):
576-
next_distance = abs(self.pistonList[i].position.x - ball_pos)
577-
if next_distance < closest:
578-
closest = next_distance
579-
closest_piston_index = i
580-
581-
if closest_piston_index > 0:
582-
nearby_pistons.append(closest_piston_index - 1)
583-
nearby_pistons.append(closest_piston_index)
584-
if closest_piston_index < self.n_pistons - 1:
585-
nearby_pistons.append(closest_piston_index + 1)
586-
587-
return nearby_pistons
588-
589-
def get_local_reward(self, prev_position, curr_position):
590-
local_reward = 0.5 * (prev_position - curr_position)
591-
return local_reward
592-
593566
def render(self):
594567
if self.render_mode is None:
595568
gymnasium.logger.warn(
@@ -613,6 +586,17 @@ def render(self):
613586
else None
614587
)
615588

589+
def _get_ball_position(self) -> int:
590+
"""Return the leftmost x-position of the ball.
591+
592+
If the ball extends beyond the leftmost wall, return the
593+
position of that wall-edge.
594+
"""
595+
ball_position = int(self.ball.position[0] - self.ball_radius)
596+
# check if the ball is touching/within the left-most wall.
597+
clipped_ball_position = max(self.wall_width, ball_position)
598+
return clipped_ball_position
599+
616600
def step(self, action):
617601
if (
618602
self.terminations[self.agent_selection]
@@ -633,30 +617,31 @@ def step(self, action):
633617

634618
self.space.step(self.dt)
635619
if self._agent_selector.is_last():
636-
ball_min_x = int(self.ball.position[0] - self.ball_radius)
637-
ball_next_x = (
638-
self.ball.position[0]
639-
- self.ball_radius
640-
+ self.ball.velocity[0] * self.dt
641-
)
642-
if ball_next_x <= self.wall_width + 1:
620+
ball_curr_pos = self._get_ball_position()
621+
622+
# A rough, first-order prediction (i.e. velocity-only) of the balls next position.
623+
# The physics environment may bounce the ball off the wall in the next time-step
624+
# without us first registering that win-condition.
625+
ball_predicted_next_pos = ball_curr_pos + self.ball.velocity[0] * self.dt
626+
# Include a single-pixel fudge-factor for the approximation.
627+
if ball_predicted_next_pos <= self.wall_width + 1:
643628
self.terminate = True
644-
# ensures that the ball can't pass through the wall
645-
ball_min_x = max(self.wall_width, ball_min_x)
629+
646630
self.draw()
647-
local_reward = self.get_local_reward(self.lastX, ball_min_x)
648-
# Opposite order due to moving right to left
649-
global_reward = (100 / self.distance) * (self.lastX - ball_min_x)
631+
632+
# The negative one is included since the x-axis increases from left-to-right. And, if the x
633+
# position decreases we want the reward to be positive, since the ball would have gotten closer
634+
# to the left-wall.
635+
global_reward = (
636+
-1
637+
* (ball_curr_pos - self.ball_prev_pos)
638+
* (100 / self.distance_to_wall_at_game_start)
639+
)
650640
if not self.terminate:
651641
global_reward += self.time_penalty
652-
total_reward = [
653-
global_reward * (1 - self.local_ratio)
654-
] * self.n_pistons # start with global reward
655-
local_pistons_to_reward = self.get_nearby_pistons()
656-
for index in local_pistons_to_reward:
657-
total_reward[index] += local_reward * self.local_ratio
658-
self.rewards = dict(zip(self.agents, total_reward))
659-
self.lastX = ball_min_x
642+
643+
self.rewards = {agent: global_reward for agent in self.agents}
644+
self.ball_prev_pos = ball_curr_pos
660645
self.frames += 1
661646
else:
662647
self._clear_rewards()

0 commit comments

Comments
 (0)