Skip to content

Commit a86bbba

Browse files
committed
test: add unit test for multi-dimensional action spaces to PPO variants
1 parent 37614b2 commit a86bbba

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/test_lstm.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,32 @@ def make_env():
244244
# In CartPole-v1, a non-recurrent policy can easily get >= 450.
245245
# In CartPoleNoVelEnv, a non-recurrent policy doesn't get more than ~50.
246246
evaluate_policy(model, env, reward_threshold=450)
247+
248+
249+
class MultiDimensionalActionSpaceEnv(gym.Env):
    """Minimal environment whose ``Box`` action space has shape ``(2, 2)``.

    Observations are sampled at random from a 10-element float32 ``Box``
    bounded by ``[-1, 1]``; the action passed to ``step`` is ignored.
    The environment only exists to exercise algorithms against a
    multi-dimensional (non-flat) continuous action space.
    """

    def __init__(self):
        # Both spaces share the same bounds and dtype; only the shape differs.
        bounds = {"low": -1, "high": 1, "dtype": np.float32}

        # Flat observation vector with 10 entries.
        self.observation_space = gym.spaces.Box(shape=(10,), **bounds)

        # Matrix-shaped action space: this is the feature under test.
        self.action_space = gym.spaces.Box(shape=(2, 2), **bounds)

    def reset(self, seed=None, options=None):
        """Forward ``seed`` to the parent ``reset`` and return ``(observation, info)``.

        The observation is a random sample from ``observation_space`` and the
        info dict is empty. ``options`` is accepted but not used.
        """
        super().reset(seed=seed)
        observation = self.observation_space.sample()
        info = {}
        return observation, info

    def step(self, action):
        """Ignore ``action`` and return a fixed-shape 5-tuple.

        The tuple is ``(observation, 1, False, False, {})`` where the
        observation is a fresh random sample from ``observation_space``.
        The two ``False`` entries occupy the terminated/truncated slots of
        the Gymnasium step tuple, so an episode never ends on its own.
        """
        observation = self.observation_space.sample()
        reward, terminated, truncated = 1, False, False
        return observation, reward, terminated, truncated, {}
271+
272+
def test_ppo_multi_dimensional_action_space():
    """Smoke test: RecurrentPPO trains on an env with a (2, 2) action space.

    Builds a single-environment vectorized wrapper around
    ``MultiDimensionalActionSpaceEnv``, creates a ``RecurrentPPO`` model with
    the ``"MlpLstmPolicy"`` policy and calls ``learn(1)``. The test passes
    as long as no exception is raised.
    """
    vec_env = make_vec_env(MultiDimensionalActionSpaceEnv, n_envs=1)
    RecurrentPPO("MlpLstmPolicy", vec_env).learn(1)

0 commit comments

Comments
 (0)