Description
🐛 Bug
If the action space of a custom environment is a `MultiDiscrete` with more than one dimension, for example `spaces.MultiDiscrete([[2], [2]])` (nvec shape `(2, 1)`), the program exits with a `TypeError` when constructing the policy.
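For concreteness (my own illustration of the shapes involved, not part of the repro below): a 1-D `MultiDiscrete` has a flat `nvec`, while the nested form produces a 2-D `nvec` and 2-D action samples.

```python
from gymnasium import spaces

flat = spaces.MultiDiscrete([2, 2])        # supported by SB3: nvec shape (2,)
nested = spaces.MultiDiscrete([[2], [2]])  # triggers the crash: nvec shape (2, 1)

print(flat.nvec.shape)        # (2,)
print(nested.nvec.shape)      # (2, 1)
print(nested.sample().shape)  # (2, 1) -- actions are 2-D as well
```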
Code example
```python
import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers.common import TimeLimit

from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env


class SimpleEnv(gym.Env):
    metadata = {"render_modes": [], "render_fps": 30}

    def __init__(self):
        super().__init__()
        # Multi-discrete action space with nvec shape (2, 1)
        self.action_space = spaces.MultiDiscrete([[2], [2]])
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32)
        self.agent = np.zeros(2, dtype=np.float32)
        self.goal = np.zeros(2, dtype=np.float32)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        super().reset(seed=seed)
        self.agent = self.np_random.uniform(0.0, 1.0, size=2).astype(np.float32)
        self.goal = self.np_random.uniform(0.0, 1.0, size=2).astype(np.float32)
        obs = np.concatenate([self.agent, self.goal]).astype(np.float32)
        return obs, {}

    def step(self, action):
        if not self.action_space.contains(action):
            raise ValueError("Invalid action")
        if action[0] == 1:
            self.agent = self.agent + np.array([0.02, 0.0], dtype=np.float32)
        else:
            self.agent = self.agent + np.array([-0.02, 0.0], dtype=np.float32)
        if action[1] == 1:
            self.agent = self.agent + np.array([0.0, 0.02], dtype=np.float32)
        else:
            self.agent = self.agent + np.array([0.0, -0.02], dtype=np.float32)
        self.agent = np.clip(self.agent, 0.0, 1.0)
        obs = np.concatenate([self.agent, self.goal]).astype(np.float32)
        distance = np.linalg.norm(self.agent - self.goal)
        reward = -float(distance)
        terminated = bool(distance < 0.04)
        truncated = False
        info = {}
        return obs, reward, terminated, truncated, info

    def render(self):
        pass

    def close(self):
        pass


env = TimeLimit(SimpleEnv(), 100)
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=200_000)
```
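The full environment is not required to trigger the crash; building the action head directly reproduces it. This is a sketch against the internals named in the traceback below (SB3 2.8.0a2), so exact names may drift between versions:

```python
import numpy as np
from stable_baselines3.common.distributions import MultiCategoricalDistribution

nvec = np.array([[2], [2]])
dist = MultiCategoricalDistribution(list(nvec))  # action_dims = [array([2]), array([2])]
dist.proba_distribution_net(latent_dim=64)       # raises the same TypeError as below
```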
Relevant log output / Error message
```
Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Traceback (most recent call last):
  File "/run/media/user/temp_win/git/stable-baselines3/test.py", line 61, in <module>
    model = PPO('MlpPolicy', env, verbose=1)
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/ppo/ppo.py", line 171, in __init__
    self._setup_model()
    ~~~~~~~~~~~~~~~~~^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/ppo/ppo.py", line 174, in _setup_model
    super()._setup_model()
    ~~~~~~~~~~~~~~~~~~~~^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 135, in _setup_model
    self.policy = self.policy_class(  # type: ignore[assignment]
    ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/policies.py", line 535, in __init__
    self._build(lr_schedule)
    ~~~~~~~~~~~^^^^^^^^^^^^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/policies.py", line 605, in _build
    self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/distributions.py", line 342, in proba_distribution_net
    action_logits = nn.Linear(latent_dim, sum(self.action_dims))
  File "/run/media/user/temp_win/git/pong-rl/venv/lib/python3.13/site-packages/torch/nn/modules/linear.py", line 106, in __init__
    torch.empty((out_features, in_features), **factory_kwargs)
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
```
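If I read the traceback right, the failure comes from `sum(self.action_dims)` in `proba_distribution_net`: with a 2-D `nvec`, the distribution is constructed from a list of numpy arrays rather than Python ints, so the sum is itself an array and `nn.Linear` rejects it. A quick check of that reading (an assumption on my part, not verified against SB3 internals beyond the traceback):

```python
import numpy as np
from gymnasium import spaces

nvec = spaces.MultiDiscrete([[2], [2]]).nvec
action_dims = list(nvec)   # [array([2]), array([2])] -- arrays, not ints
out_features = sum(action_dims)
print(out_features)        # [4] -- a numpy array, not an int
# nn.Linear(latent_dim, out_features) then calls torch.empty((out_features, latent_dim)),
# which is not a "tuple of ints" -> the TypeError above
```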
System Info
- OS: Linux-6.15.8-arch1-1-x86_64-with-glibc2.41 #1 SMP PREEMPT_DYNAMIC Thu, 24 Jul 2025 18:18:11 +0000
- Python: 3.13.5
- Stable-Baselines3: 2.8.0a2
- PyTorch: 2.8.0+cpu
- GPU Enabled: False
- Numpy: 2.3.3
- Cloudpickle: 3.1.1
- Gymnasium: 1.2.0
Checklist
- I have checked that there is no similar issue in the repo
- I have read the documentation
- I have provided a minimal and working example to reproduce the bug
- I have checked my env using the env checker
- I've used the markdown code blocks for both code and stack traces.
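In case it helps triage: a user-side workaround that appears to avoid the crash (my own sketch, not a fix in SB3 itself) is to declare the space with a flat `nvec` and reshape the action inside `step()` if a 2-D layout is needed:

```python
import numpy as np
from gymnasium import spaces

nested = spaces.MultiDiscrete([[2], [2]])
flat = spaces.MultiDiscrete(nested.nvec.flatten())  # MultiDiscrete([2 2])

action = flat.sample()                         # shape (2,) -- what SB3 produces
action_2d = action.reshape(nested.nvec.shape)  # recover the (2, 1) layout in step()
```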