95
95
from pettingzoo .utils import AgentSelector , wrappers
96
96
from pettingzoo .utils .conversions import parallel_wrapper_fn
97
97
98
- _image_library = {}
99
-
100
98
FPS = 20
101
99
102
100
__all__ = ["ManualPolicy" , "env" , "parallel_env" , "raw_env" ]
@@ -239,8 +237,7 @@ def __init__(
239
237
)
240
238
self .recentPistons = set () # Set of pistons that have touched the ball recently
241
239
self .time_penalty = time_penalty
242
- # TODO: this was a bad idea and the logic this uses should be removed at some point
243
- self .local_ratio = 0
240
+
244
241
self .ball_mass = ball_mass
245
242
self .ball_friction = ball_friction
246
243
self .ball_elasticity = ball_elasticity
@@ -466,8 +463,8 @@ def reset(self, seed=None, options=None):
466
463
- 6 * math .pi , 6 * math .pi
467
464
)
468
465
469
- self .lastX = int ( self .ball . position [ 0 ] - self . ball_radius )
470
- self .distance = self .lastX - self .wall_width
466
+ self .ball_prev_pos = self ._get_ball_position ( )
467
+ self .distance_to_wall_at_game_start = self .ball_prev_pos - self .wall_width
471
468
472
469
self .draw_background ()
473
470
self .draw ()
@@ -566,30 +563,6 @@ def draw(self):
566
563
)
567
564
self .draw_pistons ()
568
565
569
def get_nearby_pistons(self):
    """Return the indices of the piston closest to the ball and its neighbours.

    The ball's reference x-coordinate is its left edge
    (``ball.position[0] - ball_radius``) truncated to an int. The piston
    whose ``position.x`` is nearest to that coordinate is selected; on a
    tie the lowest index wins. Indices follow ``pistonList`` order (the
    original comment notes that the first piston is the leftmost one).

    Returns:
        A list of one to three consecutive indices: the closest piston and,
        where they exist, the pistons immediately before and after it.
        An empty list is returned when there are no pistons.
    """
    # Nothing to pick from; avoids indexing into an empty piston list.
    if self.n_pistons <= 0:
        return []

    ball_pos = int(self.ball.position[0] - self.ball_radius)

    # min() returns the first minimal element, matching a strict "<" scan
    # that starts at index 0.
    closest_piston_index = min(
        range(self.n_pistons),
        key=lambda i: abs(self.pistonList[i].position.x - ball_pos),
    )

    # Clamp the one-piston-wide window to the valid index range.
    first = max(closest_piston_index - 1, 0)
    last = min(closest_piston_index + 1, self.n_pistons - 1)
    return list(range(first, last + 1))
588
-
589
def get_local_reward(self, prev_position, curr_position):
    """Return the local reward for the ball's movement over one step.

    The reward is half of ``prev_position - curr_position``, so it is
    positive when the position value decreased and negative when it
    increased.

    Args:
        prev_position: Ball x-position at the previous step.
        curr_position: Ball x-position at the current step.

    Returns:
        ``0.5 * (prev_position - curr_position)``.
    """
    return 0.5 * (prev_position - curr_position)
592
-
593
566
def render (self ):
594
567
if self .render_mode is None :
595
568
gymnasium .logger .warn (
@@ -613,6 +586,17 @@ def render(self):
613
586
else None
614
587
)
615
588
589
def _get_ball_position(self) -> int:
    """Return the leftmost x-position of the ball.

    The left edge is ``ball.position[0] - ball_radius`` truncated to an
    int. If the ball extends beyond the leftmost wall, the position of
    that wall edge (``wall_width``) is returned instead, so the result is
    never smaller than ``wall_width``.
    """
    left_edge = int(self.ball.position[0] - self.ball_radius)
    # Clip so a ball touching or overlapping the left wall reports the
    # wall edge rather than a position inside the wall.
    return max(self.wall_width, left_edge)
599
+
616
600
def step (self , action ):
617
601
if (
618
602
self .terminations [self .agent_selection ]
@@ -633,30 +617,31 @@ def step(self, action):
633
617
634
618
self .space .step (self .dt )
635
619
if self ._agent_selector .is_last ():
636
- ball_min_x = int (self .ball .position [0 ] - self .ball_radius )
637
- ball_next_x = (
638
- self .ball .position [0 ]
639
- - self .ball_radius
640
- + self .ball .velocity [0 ] * self .dt
641
- )
642
- if ball_next_x <= self .wall_width + 1 :
620
+ ball_curr_pos = self ._get_ball_position ()
621
+
622
+ # A rough, first-order prediction (i.e. velocity-only) of the ball's next position.
623
+ # The physics environment may bounce the ball off the wall in the next time-step
624
+ # without us first registering that win-condition.
625
+ ball_predicted_next_pos = ball_curr_pos + self .ball .velocity [0 ] * self .dt
626
+ # Include a single-pixel fudge-factor for the approximation.
627
+ if ball_predicted_next_pos <= self .wall_width + 1 :
643
628
self .terminate = True
644
- # ensures that the ball can't pass through the wall
645
- ball_min_x = max (self .wall_width , ball_min_x )
629
+
646
630
self .draw ()
647
- local_reward = self .get_local_reward (self .lastX , ball_min_x )
648
- # Opposite order due to moving right to left
649
- global_reward = (100 / self .distance ) * (self .lastX - ball_min_x )
631
+
632
+ # The negative one is included since the x-axis increases from left-to-right. And, if the x
633
+ # position decreases we want the reward to be positive, since the ball would have gotten closer
634
+ # to the left-wall.
635
+ global_reward = (
636
+ - 1
637
+ * (ball_curr_pos - self .ball_prev_pos )
638
+ * (100 / self .distance_to_wall_at_game_start )
639
+ )
650
640
if not self .terminate :
651
641
global_reward += self .time_penalty
652
- total_reward = [
653
- global_reward * (1 - self .local_ratio )
654
- ] * self .n_pistons # start with global reward
655
- local_pistons_to_reward = self .get_nearby_pistons ()
656
- for index in local_pistons_to_reward :
657
- total_reward [index ] += local_reward * self .local_ratio
658
- self .rewards = dict (zip (self .agents , total_reward ))
659
- self .lastX = ball_min_x
642
+
643
+ self .rewards = {agent : global_reward for agent in self .agents }
644
+ self .ball_prev_pos = ball_curr_pos
660
645
self .frames += 1
661
646
else :
662
647
self ._clear_rewards ()
0 commit comments