diff --git a/pettingzoo/butterfly/pistonball/pistonball.py b/pettingzoo/butterfly/pistonball/pistonball.py
index f40927e93..b4ce4436b 100644
--- a/pettingzoo/butterfly/pistonball/pistonball.py
+++ b/pettingzoo/butterfly/pistonball/pistonball.py
@@ -95,8 +95,6 @@
 from pettingzoo.utils import AgentSelector, wrappers
 from pettingzoo.utils.conversions import parallel_wrapper_fn
 
-_image_library = {}
-
 FPS = 20
 
 __all__ = ["ManualPolicy", "env", "parallel_env", "raw_env"]
@@ -239,8 +237,7 @@ def __init__(
         )
         self.recentPistons = set()  # Set of pistons that have touched the ball recently
         self.time_penalty = time_penalty
-        # TODO: this was a bad idea and the logic this uses should be removed at some point
-        self.local_ratio = 0
+
         self.ball_mass = ball_mass
         self.ball_friction = ball_friction
         self.ball_elasticity = ball_elasticity
@@ -466,8 +463,8 @@ def reset(self, seed=None, options=None):
             -6 * math.pi, 6 * math.pi
         )
 
-        self.lastX = int(self.ball.position[0] - self.ball_radius)
-        self.distance = self.lastX - self.wall_width
+        self.ball_prev_pos = self._get_ball_position()
+        self.distance_to_wall_at_game_start = self.ball_prev_pos - self.wall_width
 
         self.draw_background()
         self.draw()
@@ -566,30 +563,6 @@ def draw(self):
         )
         self.draw_pistons()
 
-    def get_nearby_pistons(self):
-        # first piston = leftmost
-        nearby_pistons = []
-        ball_pos = int(self.ball.position[0] - self.ball_radius)
-        closest = abs(self.pistonList[0].position.x - ball_pos)
-        closest_piston_index = 0
-        for i in range(self.n_pistons):
-            next_distance = abs(self.pistonList[i].position.x - ball_pos)
-            if next_distance < closest:
-                closest = next_distance
-                closest_piston_index = i
-
-        if closest_piston_index > 0:
-            nearby_pistons.append(closest_piston_index - 1)
-        nearby_pistons.append(closest_piston_index)
-        if closest_piston_index < self.n_pistons - 1:
-            nearby_pistons.append(closest_piston_index + 1)
-
-        return nearby_pistons
-
-    def get_local_reward(self, prev_position, curr_position):
-        local_reward = 0.5 * (prev_position - curr_position)
-        return local_reward
-
     def render(self):
         if self.render_mode is None:
             gymnasium.logger.warn(
@@ -612,6 +585,15 @@
             if self.render_mode == "rgb_array"
             else None
         )
+
+    def _get_ball_position(self) -> int:
+        """Return the leftmost x-position of the ball. If the ball
+        extends beyond the leftmost wall, return the position of that
+        wall-edge."""
+        ball_position = int(self.ball.position[0] - self.ball_radius)
+        # Check whether the ball is touching/within the leftmost wall.
+        clipped_ball_position = max(self.wall_width, ball_position)
+        return clipped_ball_position
 
     def step(self, action):
         if (
@@ -633,30 +615,30 @@
 
         self.space.step(self.dt)
         if self._agent_selector.is_last():
-            ball_min_x = int(self.ball.position[0] - self.ball_radius)
-            ball_next_x = (
-                self.ball.position[0]
-                - self.ball_radius
-                + self.ball.velocity[0] * self.dt
+            ball_curr_pos = self._get_ball_position()
+
+            # A rough, first-order (i.e. velocity-only) prediction of the ball's next
+            # position: the physics engine may bounce the ball off the wall in the next
+            # time-step before we have registered that win condition.
+            ball_predicted_next_pos = (
+                ball_curr_pos
+                + self.ball.velocity[0] * self.dt
             )
-            if ball_next_x <= self.wall_width + 1:
+            # Include a single-pixel fudge factor for the approximation.
+            if ball_predicted_next_pos <= self.wall_width + 1:
                 self.terminate = True
-            # ensures that the ball can't pass through the wall
-            ball_min_x = max(self.wall_width, ball_min_x)
+
             self.draw()
-            local_reward = self.get_local_reward(self.lastX, ball_min_x)
-            # Opposite order due to moving right to left
-            global_reward = (100 / self.distance) * (self.lastX - ball_min_x)
+
+            # The factor of -1 is included because the x-axis increases from left to
+            # right, so a decrease in the ball's x-position (i.e. the ball getting
+            # closer to the left wall) should yield a positive reward.
+            global_reward = -1 * (ball_curr_pos - self.ball_prev_pos) * (100 / self.distance_to_wall_at_game_start)
             if not self.terminate:
                 global_reward += self.time_penalty
-            total_reward = [
-                global_reward * (1 - self.local_ratio)
-            ] * self.n_pistons  # start with global reward
-            local_pistons_to_reward = self.get_nearby_pistons()
-            for index in local_pistons_to_reward:
-                total_reward[index] += local_reward * self.local_ratio
-            self.rewards = dict(zip(self.agents, total_reward))
-            self.lastX = ball_min_x
+
+            self.rewards = {agent: global_reward for agent in self.agents}
+            self.ball_prev_pos = ball_curr_pos
             self.frames += 1
         else:
             self._clear_rewards()
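Note on the new reward rule: every piston now receives the same shared reward each step, -(delta_x) * (100 / d0), where delta_x is the change in the ball's clipped leftmost x-position and d0 is the ball's starting distance to the left wall. Summed over an episode, the position terms telescope, so pushing the ball all the way to the wall yields roughly 100 in total (plus any accumulated time_penalty). Below is a minimal standalone sketch of that formula; the numeric values (wall_width, the ball positions) are made up for illustration and do not come from the patch.

    # Hypothetical values for illustration only.
    wall_width = 40
    ball_start_pos = 680
    distance_to_wall_at_game_start = ball_start_pos - wall_width  # 640 px to cover

    ball_prev_pos = 680
    ball_curr_pos = 648  # the ball moved 32 px toward the left wall this step

    # Same formula as the patched step(): negated so leftward motion is rewarded.
    global_reward = -1 * (ball_curr_pos - ball_prev_pos) * (100 / distance_to_wall_at_game_start)
    print(global_reward)  # 5.0 -- broadcast identically to every piston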