@@ -248,29 +248,18 @@ def make_env():
 
 class MultiDimensionalActionSpaceEnv(gym.Env):
     def __init__(self):
-        self.observation_space = gym.spaces.Box(
-            low=-1,
-            high=1,
-            shape=(10,),
-            dtype=np.float32,
-        )
-
-        self.action_space = gym.spaces.Box(
-            low=-1,
-            high=1,
-            shape=(2, 2),
-            dtype=np.float32,
-        )
+        self.observation_space = spaces.Box(low=-1, high=1, shape=(10,), dtype=np.float32)
+        self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32)
 
     def reset(self, seed=None, options=None):
         super().reset(seed=seed)
         return self.observation_space.sample(), {}
 
     def step(self, action):
-        return self.observation_space.sample(), 1, False, False, {}
+        return self.observation_space.sample(), 1, np.random.rand() > 0.8, False, {}
 
 
 def test_ppo_multi_dimensional_action_space():
-    env = make_vec_env(MultiDimensionalActionSpaceEnv, n_envs=1)
-    model = RecurrentPPO("MlpLstmPolicy", env)
-    model.learn(1)
+    env = MultiDimensionalActionSpaceEnv()
+    model = RecurrentPPO("MlpLstmPolicy", env, n_steps=64, n_epochs=2).learn(64)
+    evaluate_policy(model, model.get_env())