1717from stable_baselines3 .common .vec_env .patch_gym import _patch_env
1818
1919
20- def _worker (
20+ def _worker ( # noqa: C901
2121 remote : mp .connection .Connection ,
2222 parent_remote : mp .connection .Connection ,
2323 env_fn_wrapper : CloudpickleWrapper ,
@@ -58,6 +58,12 @@ def _worker(
5858 remote .send (method (* data [1 ], ** data [2 ]))
5959 elif cmd == "get_attr" :
6060 remote .send (env .get_wrapper_attr (data ))
61+ elif cmd == "has_attr" :
62+ try :
63+ env .get_wrapper_attr (data )
64+ remote .send (True )
65+ except AttributeError :
66+ remote .send (False )
6167 elif cmd == "set_attr" :
6268 remote .send (setattr (env , data [0 ], data [1 ])) # type: ignore[func-returns-value]
6369 elif cmd == "is_wrapped" :
@@ -66,6 +72,8 @@ def _worker(
6672 raise NotImplementedError (f"`{ cmd } ` is not implemented in the worker" )
6773 except EOFError :
6874 break
75+ except KeyboardInterrupt :
76+ break
6977
7078
7179class SubprocVecEnv (VecEnv ):
@@ -165,6 +173,13 @@ def get_images(self) -> Sequence[Optional[np.ndarray]]:
165173 outputs = [pipe .recv () for pipe in self .remotes ]
166174 return outputs
167175
176+ def has_attr (self , attr_name : str ) -> bool :
177+ """Check if an attribute exists for a vectorized environment. (see base class)."""
178+ target_remotes = self ._get_target_remotes (indices = None )
179+ for remote in target_remotes :
180+ remote .send (("has_attr" , attr_name ))
181+ return all ([remote .recv () for remote in target_remotes ])
182+
168183 def get_attr (self , attr_name : str , indices : VecEnvIndices = None ) -> list [Any ]:
169184 """Return attribute from vectorized environment (see base class)."""
170185 target_remotes = self ._get_target_remotes (indices )
0 commit comments