diff --git a/coltra/envs/__init__.py b/coltra/envs/__init__.py index 21bd48f2..4398af3a 100644 --- a/coltra/envs/__init__.py +++ b/coltra/envs/__init__.py @@ -2,7 +2,6 @@ from .base_env import MultiAgentEnv from .subproc_vec_env import SubprocVecEnv from .probe_envs import probe_env_classes -from .gym_envs import MultiGymEnv, import_bullet +from .gym_envs import MultiGymEnv from .base_env import Observation, Action from .smartnav_envs import SmartNavEnv -from .pettingzoo_envs import PettingZooEnv diff --git a/coltra/envs/gym_envs.py b/coltra/envs/gym_envs.py index 499bad61..095896fa 100644 --- a/coltra/envs/gym_envs.py +++ b/coltra/envs/gym_envs.py @@ -10,11 +10,6 @@ from coltra.envs.subproc_vec_env import VecEnv, SubprocVecEnv -def import_bullet(): - # noinspection PyUnresolvedReferences - import pybullet_envs - - class MultiGymEnv(MultiAgentEnv): """ A wrapper for environments that can be `gym.make`'d @@ -32,8 +27,6 @@ def __init__( super().__init__(seed) if wrappers is None: wrappers = [] - if "Bullet" in env_name: - import_fn = import_bullet import_fn() diff --git a/coltra/envs/pettingzoo_envs.py b/coltra/envs/pettingzoo_envs.py deleted file mode 100644 index e37268ae..00000000 --- a/coltra/envs/pettingzoo_envs.py +++ /dev/null @@ -1,66 +0,0 @@ -from typing import Callable, Any -import numpy as np -from gym.spaces import Discrete - -from pettingzoo.utils.env import ParallelEnv - -from coltra.buffers import Observation -from coltra.envs import MultiAgentEnv, SubprocVecEnv -from coltra.envs.base_env import VecEnv, ActionDict, StepReturn, ObsDict - - -class PettingZooEnv(MultiAgentEnv): - def __init__(self, env_creator: Callable[..., ParallelEnv], **kwargs): - super().__init__() - self.pz_env = env_creator(**kwargs) - self.active_agents = self.pz_env.possible_agents - - agent_name = self.pz_env.possible_agents[0] - self.action_space = self.pz_env.action_spaces[agent_name] - self.observation_space = self.pz_env.observation_spaces[agent_name] - - self.is_discrete_action = isinstance(self.action_space, Discrete) - - def _embed_observation(self, obs: np.ndarray) -> Observation: - shape = obs.shape - if len(shape) == 1: - return Observation(vector=obs) - elif len(shape) == 2: - return Observation(buffer=obs) - elif len(shape) == 3: - return Observation(image=obs) - else: - raise ValueError(f"Observation shape {obs.shape} not supported") - - def reset(self, *args, **kwargs) -> ObsDict: - obs = self.pz_env.reset() - return {key: self._embed_observation(obs[key]) for key in obs} - - def step(self, action_dict: ActionDict) -> StepReturn: - if self.is_discrete_action: - action = { - agent_id: action_dict[agent_id].discrete for agent_id in action_dict - } - else: - action = { - agent_id: action_dict[agent_id].continuous for agent_id in action_dict - } - - obs, reward, done, info = self.pz_env.step(action) - - if all(done.values()): - obs = self.pz_env.reset() - - obs = {key: self._embed_observation(obs[key]) for key in obs} - return obs, reward, done, info - - def render(self, mode="rgb_array"): - return self.pz_env.render(mode) - - @classmethod - def get_venv(cls, workers: int = 8, **env_kwargs) -> SubprocVecEnv: - - venv = SubprocVecEnv( - [cls.get_env_creator(**env_kwargs) for i in range(workers)] - ) - return venv diff --git a/requirements.txt b/requirements.txt index 9fffea05..11795774 100644 --- a/requirements.txt +++ b/requirements.txt @@ -17,7 +17,6 @@ ipykernel>=5.4.3 numba>=0.54.0 pytest>=6.2.2 coverage>=5.5.0 -pybullet>=3.1.9 opencv-python~=3.4.15.55 cloudpickle~=2.0.0 pillow~=8.4.0 diff --git a/scripts/enjoy_gym.py b/scripts/enjoy_gym.py index 426c8f4c..d24374ca 100644 --- a/scripts/enjoy_gym.py +++ b/scripts/enjoy_gym.py @@ -12,8 +12,6 @@ from coltra.trainers import PPOCrowdTrainer from coltra.envs import MultiGymEnv -import pybullet_envs - class Parser(BaseParser): path: str diff --git a/scripts/train_gym.py b/scripts/train_gym.py index 6fdf03f3..bcf6ffc8 100644 --- a/scripts/train_gym.py +++ b/scripts/train_gym.py @@ -11,8 +11,6 @@ from coltra.trainers import PPOCrowdTrainer from coltra.envs import MultiGymEnv -import pybullet_envs - from coltra.wrappers import ObsVecNormWrapper, LastRewardWrapper from coltra.wrappers.agent_wrappers import RetNormWrapper diff --git a/tests/test_pybullet.py b/tests/test_pybullet.py deleted file mode 100644 index 13941266..00000000 --- a/tests/test_pybullet.py +++ /dev/null @@ -1,10 +0,0 @@ -from coltra.envs.gym_envs import MultiGymEnv, import_bullet - - -def test_init(): - env = MultiGymEnv.get_venv( - 8, env_name="HopperBulletEnv-v0", seed=0, import_fn=import_bullet - ) - obs = env.reset() - - assert len(obs) == 8