Skip to content

Commit a07bb1b

Browse files
committed
Update test and version
1 parent dee0acd commit a07bb1b

File tree

3 files changed

+8
-19
lines changed

3 files changed

+8
-19
lines changed

docs/misc/changelog.rst

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
3 3
Changelog
4 4
==========
5 5

6-
Release 2.8.0a0 (WIP)
6+
Release 2.8.0a1 (WIP)
7 7
--------------------------
8 8

9 9
Breaking Changes:

sb3_contrib/version.txt

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1 +1 @@
1-
2.8.0a0
1+
2.8.0a1

tests/test_lstm.py

Lines changed: 6 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -248,29 +248,18 @@ def make_env():
248 248

249 249
class MultiDimensionalActionSpaceEnv(gym.Env):
250 250
def __init__(self):
251-
self.observation_space = gym.spaces.Box(
252-
low=-1,
253-
high=1,
254-
shape=(10,),
255-
dtype=np.float32,
256-
)
257-
258-
self.action_space = gym.spaces.Box(
259-
low=-1,
260-
high=1,
261-
shape=(2, 2),
262-
dtype=np.float32,
263-
)
251+
self.observation_space = spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32)
252+
self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
264 253

265 254
def reset(self, seed=None, options=None):
266 255
super().reset(seed=seed)
267 256
return self.observation_space.sample(), {}
268 257

269 258
def step(self, action):
270-
return self.observation_space.sample(), 1, False, False, {}
259+
return self.observation_space.sample(), 1, np.random.rand() > 0.8, False, {}
271 260

272 261

273 262
def test_ppo_multi_dimensional_action_space():
274-
env = make_vec_env(MultiDimensionalActionSpaceEnv, n_envs=1)
275-
model = RecurrentPPO("MlpLstmPolicy", env)
276-
model.learn(1)
263+
env = MultiDimensionalActionSpaceEnv()
264+
model = RecurrentPPO("MlpLstmPolicy", env, n_steps=64, n_epochs=2).learn(64)
265+
evaluate_policy(model, model.get_env())

0 commit comments

Comments (0)