[Bug]: Error in environments with multi-discrete action spaces when the nvec dimension is higher than 1D #2207

@unexploredtest

Description

🐛 Bug

If the action space of a custom environment is MultiDiscrete and its nvec is more than 1-dimensional, for example [[2], [2]], the program exits with an error.
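
For context, a minimal check of how the nested nvec used here differs from the usual flat one (assuming Gymnasium's MultiDiscrete semantics):

from gymnasium import spaces

flat_space = spaces.MultiDiscrete([2, 2])        # the supported case: nvec.shape == (2,)
nested_space = spaces.MultiDiscrete([[2], [2]])  # the failing case: nvec.shape == (2, 1)

print(flat_space.nvec.shape)    # (2,)
print(nested_space.nvec.shape)  # (2, 1)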

Code example

import numpy as np
import gymnasium as gym
from gymnasium import spaces
from gymnasium.wrappers.common import TimeLimit

from stable_baselines3 import PPO
from stable_baselines3.common.env_checker import check_env

class SimpleEnv(gym.Env):
    metadata = {"render_modes": [], "render_fps": 30}

    def __init__(self):
        super().__init__()
        # Multi-discrete action space whose nvec has shape (2, 1)
        self.action_space = spaces.MultiDiscrete([[2],[2]])
        self.observation_space = spaces.Box(low=0.0, high=1.0, shape=(4,), dtype=np.float32)

        self.agent = np.zeros(2, dtype=np.float32)
        self.goal = np.zeros(2, dtype=np.float32)

    def reset(self, *, seed: int | None = None, options: dict | None = None):
        super().reset(seed=seed)

        self.agent = self.np_random.uniform(0.0, 1.0, size=2).astype(np.float32)
        self.goal = self.np_random.uniform(0.0, 1.0, size=2).astype(np.float32)
        obs = np.concatenate([self.agent, self.goal]).astype(np.float32)

        return obs, {}

    def step(self, action):
        if not self.action_space.contains(action):
            raise ValueError("Invalid action")

        if action[0] == 1:
            self.agent = self.agent + np.array([0.02, 0.0], dtype=np.float32)
        else:
            self.agent = self.agent + np.array([-0.02, 0.0], dtype=np.float32)
        if action[1] == 1:
            self.agent = self.agent + np.array([0.0, 0.02], dtype=np.float32)
        else:
            self.agent = self.agent + np.array([0.0, -0.02], dtype=np.float32)

        self.agent = np.clip(self.agent, 0.0, 1.0)

        obs = np.concatenate([self.agent, self.goal]).astype(np.float32)
        distance = np.linalg.norm(self.agent - self.goal)
        reward = -float(distance)
        terminated = bool(distance < 0.04)
        truncated = False

        info = {}
        return obs, reward, terminated, truncated, info

    def render(self):
        pass

    def close(self):
        pass

env = TimeLimit(SimpleEnv(), 100)
model = PPO('MlpPolicy', env, verbose=1)
model.learn(total_timesteps=200_000)
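
As an aside, a possible workaround until this is fixed (a sketch on my part, not a confirmed solution) is to keep nvec flat: MultiDiscrete([2, 2]) describes the same two binary sub-actions with a 1D nvec, which SB3 handles.

class FlatSimpleEnv(SimpleEnv):
    def __init__(self):
        super().__init__()
        # Same two binary sub-actions, but declared with a flat nvec of shape (2,)
        self.action_space = spaces.MultiDiscrete([2, 2])

flat_model = PPO('MlpPolicy', TimeLimit(FlatSimpleEnv(), 100), verbose=0)
flat_model.learn(total_timesteps=2_048)  # runs without the error shown below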

Relevant log output / Error message

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Traceback (most recent call last):
  File "/run/media/user/temp_win/git/stable-baselines3/test.py", line 61, in <module>
    model = PPO('MlpPolicy', env, verbose=1)
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/ppo/ppo.py", line 171, in __init__
    self._setup_model()
    ~~~~~~~~~~~~~~~~~^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/ppo/ppo.py", line 174, in _setup_model
    super()._setup_model()
    ~~~~~~~~~~~~~~~~~~~~^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 135, in _setup_model
    self.policy = self.policy_class(  # type: ignore[assignment]
                  ~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
        self.observation_space, self.action_space, self.lr_schedule, use_sde=self.use_sde, **self.policy_kwargs
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/policies.py", line 535, in __init__
    self._build(lr_schedule)
    ~~~~~~~~~~~^^^^^^^^^^^^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/policies.py", line 605, in _build
    self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi)
                      ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/run/media/user/temp_win/git/stable-baselines3/stable_baselines3/common/distributions.py", line 342, in proba_distribution_net
    action_logits = nn.Linear(latent_dim, sum(self.action_dims))
  File "/run/media/user/temp_win/git/pong-rl/venv/lib/python3.13/site-packages/torch/nn/modules/linear.py", line 106, in __init__
    torch.empty((out_features, in_features), **factory_kwargs)
    ~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
 * (tuple of ints size, *, torch.memory_format memory_format = None, Tensor out = None, torch.dtype dtype = None, torch.layout layout = None, torch.device device = None, bool pin_memory = False, bool requires_grad = False)
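
From the traceback, my reading (an assumption, not a verified analysis of the SB3 internals) is that the per-dimension action sizes derived from a 2D nvec end up as length-1 arrays rather than ints, so their sum is a NumPy array and nn.Linear rejects it as out_features:

import numpy as np
import torch.nn as nn
from gymnasium import spaces

nvec = spaces.MultiDiscrete([[2], [2]]).nvec  # array([[2], [2]]), shape (2, 1)
action_dims = list(nvec)                      # [array([2]), array([2])], not [2, 2]
print(sum(action_dims))                       # a length-1 ndarray, not an int

try:
    nn.Linear(64, sum(action_dims))           # raises the same TypeError as above
except TypeError as err:
    print(err)

nn.Linear(64, int(np.sum(nvec)))              # fine once the total is a plain int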

System Info

  • OS: Linux-6.15.8-arch1-1-x86_64-with-glibc2.41 #1 SMP PREEMPT_DYNAMIC Thu, 24 Jul 2025 18:18:11 +0000
  • Python: 3.13.5
  • Stable-Baselines3: 2.8.0a2
  • PyTorch: 2.8.0+cpu
  • GPU Enabled: False
  • Numpy: 2.3.3
  • Cloudpickle: 3.1.1
  • Gymnasium: 1.2.0

    Labels

    custom gym env: Issue related to Custom Gym Env
