
Commit 0cd46f4

AlexPasqua and SimRey committed
Created env 'catching point' (to be tested)
Co-authored-by: simrey <[email protected]>
1 parent 6994f2d commit 0cd46f4

2 files changed (+108, -1 lines)


sb3_contrib/common/envs/__init__.py

Lines changed: 4 additions & 1 deletion
@@ -3,5 +3,8 @@
     InvalidActionEnvMultiBinary,
     InvalidActionEnvMultiDiscrete,
 )
+from sb3_contrib.common.envs.hybrid_actions_env import (
+    CatchingPointEnv
+)
 
-__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete"]
+__all__ = ["InvalidActionEnvDiscrete", "InvalidActionEnvMultiBinary", "InvalidActionEnvMultiDiscrete", "CatchingPointEnv"]
sb3_contrib/common/envs/hybrid_actions_env.py

Lines changed: 104 additions & 0 deletions (new file)

from typing import Optional

import gymnasium as gym
import numpy as np
from gymnasium import spaces


class CatchingPointEnv(gym.Env):
    """
    Environment for Hybrid PPO for the 'Catching Point' task of the paper
    'Hybrid Actor-Critic Reinforcement Learning in Parameterized Action Space', Fan et al.
    (https://arxiv.org/pdf/1903.01344)
    """

    def __init__(
        self,
        arena_size: float = 1.0,
        move_dist: float = 0.05,
        catch_radius: float = 0.05,
        max_catches: int = 10,
        max_steps: int = 200,
    ):
        super().__init__()
        self.max_steps = max_steps
        self.max_catches = max_catches
        self.arena_size = arena_size
        self.move_dist = move_dist
        self.catch_radius = catch_radius

        # action space: a discrete choice (MOVE/CATCH) plus a continuous direction
        self.action_space = spaces.Tuple(
            spaces=(
                spaces.MultiDiscrete([2]),  # MOVE=0, CATCH=1
                spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32),  # direction
            )
        )

        # observation: [agent_x, agent_y, target_x, target_y, catches_left, step_norm]
        obs_low = np.array([-arena_size, -arena_size, -arena_size, -arena_size, 0.0, 0.0], dtype=np.float32)
        obs_high = np.array([arena_size, arena_size, arena_size, arena_size, float(max_catches), 1.0], dtype=np.float32)
        self.observation_space = spaces.Box(obs_low, obs_high, dtype=np.float32)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None) -> tuple[np.ndarray, dict]:
        """
        Reset the environment to an initial state and return the initial observation and info dict.
        """
        super().reset(seed=seed)
        self.agent_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32)
        self.target_pos = self.np_random.uniform(-self.arena_size, self.arena_size, size=2).astype(np.float32)
        self.catches_used = 0
        self.step_count = 0
        return self._get_obs(), {}

    def step(self, action: tuple[np.ndarray, np.ndarray]) -> tuple[np.ndarray, float, bool, bool, dict]:
        """
        Take a step in the environment using the provided action.

        :param action: A tuple containing the discrete action and continuous parameters.
        :return: observation, reward, terminated, truncated, info
        """
        action_d = int(action[0][0])
        dir_vec = action[1]
        terminated = False

        # step penalty
        reward = -0.01

        # MOVE
        if action_d == 0:
            norm = np.linalg.norm(dir_vec)
            if norm > 0:  # ignore zero-length direction vectors to avoid division by zero
                dir_u = dir_vec / norm
                self.agent_pos = (self.agent_pos + dir_u * self.move_dist).astype(np.float32)
                # clamp to arena
                self.agent_pos = np.clip(self.agent_pos, -self.arena_size, self.arena_size)

        # CATCH
        else:
            self.catches_used += 1
            dist = np.linalg.norm(self.agent_pos - self.target_pos)
            if dist <= self.catch_radius:
                reward = 1.0  # caught the target
                terminated = True
            elif self.catches_used >= self.max_catches:
                terminated = True  # catch attempts exhausted

        self.step_count += 1
        truncated = self.step_count >= self.max_steps

        obs = self._get_obs()
        info = {"caught": (reward > 0)}
        return obs, float(reward), terminated, truncated, info

    def _get_obs(self) -> np.ndarray:
        """
        Get the current observation.
        """
        step_norm = self.step_count / self.max_steps
        catches_left = self.max_catches - self.catches_used
        return np.concatenate((
            self.agent_pos,
            self.target_pos,
            np.array([catches_left, step_norm], dtype=np.float32),
        ))
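
For context, here is a minimal random-rollout sketch of how the environment is meant to be driven (illustrative only, not part of the commit; it assumes the gymnasium-style reset/step API shown above):

from sb3_contrib.common.envs import CatchingPointEnv

# Illustrative rollout with random hybrid actions (discrete MOVE/CATCH + continuous direction).
env = CatchingPointEnv(arena_size=1.0, max_steps=200)
obs, info = env.reset(seed=0)

for _ in range(50):
    action = env.action_space.sample()  # tuple: (array of shape (1,), array of shape (2,))
    obs, reward, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
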
