
No example for UnityMLAgentsEnv or Wrapper for single or multiagent training #2781


Open · kylelevy opened this issue Feb 12, 2025 · 11 comments

@kylelevy

Following up on a discussion post.

TL;DR: I am trying to create an example notebook for UnityMLAgentsEnv/Wrapper, but I am unable to find docs or a reference for how to interact with the keys the environment produces. I have seen a YouTube video referenced by the devs showing that it is possible, but not how. If I can get it working, I would love to contribute it to the examples.

Discussed in #2697

Originally posted by kylelevy January 15, 2025
Hello there,

I am new to TorchRL and am trying to use it to train a PPO algorithm with Unity ML-Agents. Currently, I am just trying to get a head_balance example scene running, but I have been having some difficulty using the env, as it does not line up with the setup from the other tutorials.

The UnityMLAgentsEnv is working and returns an env with the 12 agents in the head_balance scene. As the UnityMLAgentsEnv docs show in their example, each agent sits inside one group in the TensorDict with its own fields such as continuous_action, and the rollout works.

The problem, however, is that the keys do not match either the Multiagent PPO Tutorial or the Multiagent DDPG Tutorial, and I cannot find an example of how to handle this format. In both tutorials, the expected keys for the other environments are ('agent', 'action'), ('agent', 'observation'), etc., since all agents are homogeneous and stacked into one tensor right from the environment. The ML-Agents head_balance example is not stacked, so I am not sure how to correctly apply the individual agent keys to the policy or critic modules.
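To make the mismatch concrete, here is a rough illustration of the two key layouts (names assumed from the tutorials and the UnityMLAgentsEnv docs, not copied from an actual run):

# Tutorial-style envs (e.g. VMAS): homogeneous agents stacked into one group
stacked_obs = rollout["agents", "observation"]       # shape [..., n_agents, obs_dim]
stacked_act = rollout["agents", "action"]            # shape [..., n_agents, act_dim]

# UnityMLAgentsEnv default: one sub-TensorDict per agent
obs_0 = rollout["agents", "agent_0", "observation_0"]      # shape [..., obs_dim]
act_0 = rollout["agents", "agent_0", "continuous_action"]  # shape [..., act_dim]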

I have been working on getting this example up and running for a little while and find myself stuck on how to correctly interface this style of environment with the different modules. Could I please get some advice or direction on how to go about this?

P.S. If I can get head_balance working with TorchRL and the UnityMLAgentsEnv, I would be more than happy to open a pull request and contribute it so others can avoid the same headaches.

Setup:

  • python3.12
  • torchrl==0.6.0
  • tensordict==0.6.1
  • mlagents==0.28.0
  • mlagents-env==0.28.0

Code:

import multiprocessing
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    TransformedEnv,
    RewardSum
)
from torchrl.envs import UnityMLAgentsEnv, MarlGroupMapType
from torchrl.envs.utils import check_env_specs
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal, AdditiveGaussianModule
from tqdm import tqdm

# Devices
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

# Sampling
frames_per_batch = 6_000  # Number of team frames collected per training iteration
n_iters = 10  # Number of sampling and training iterations
total_frames = frames_per_batch * n_iters

# Training
num_epochs = 30  # Number of optimization steps per training iteration
minibatch_size = 400  # Size of the mini-batches in each optimization step
lr = 3e-4  # Learning rate
max_grad_norm = 1.0  # Maximum norm for the gradients

# PPO
clip_epsilon = 0.2  # clip value for PPO loss
gamma = 0.99  # discount factor
lmbda = 0.9  # lambda for generalised advantage estimation
entropy_eps = 1e-4  # coefficient of the entropy term in the PPO loss

base_env = UnityMLAgentsEnv(registered_name="3DBall", device=device, group_map=MarlGroupMapType.ALL_IN_ONE_GROUP)

env = TransformedEnv(
    base_env,
    RewardSum(
        in_keys=[key for key in base_env.reward_keys if key[2] == "reward"], # exclude group reward
        reset_keys=base_env.reset_keys
    )
)

check_env_specs(base_env)

n_rollout_steps = 5
rollout = env.rollout(n_rollout_steps)

share_parameters_policy = True

policy_net = nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs=env.observation_spec['agents']['agent_0']['observation_0'].shape[-1],
        n_agent_outputs=env.action_spec['agents']['agent_0']['continuous_action'].shape[-1],
        n_agents=len(env.group_map['agents']),
        centralised=False,
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh
    ),
    NormalParamExtractor(),
)

policy_module = TensorDictModule(
    policy_net, 
    in_keys=[("agents", agent, "observation_0") for agent in env.group_map["agents"]],
    out_keys=[("agents", agent, "action_param") for agent in env.group_map["agents"]],
)

policy = ProbabilisticActor(
    module=policy_module,
    spec=env.full_action_spec["agents", "agent_0", "continuous_action"],
    in_keys=[("agents", agent, "action_param") for agent in env.group_map["agents"]],
    out_keys=[("agents", agent, "continuous_action") for agent in env.group_map["agents"]],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec['agents']['agent_0']['continuous_action'].space.low,
        "high": env.action_spec['agents']['agent_0']['continuous_action'].space.high,
    },
    return_log_prob=False,
)

Error

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 1
----> 1 policy = ProbabilisticActor(
      2     module=policy_module,
      3     spec=env.full_action_spec["agents", "agent_0", "continuous_action"],
      4     in_keys=[("agents", agent, "action_param") for agent in env.group_map["agents"]],
      5     out_keys=[("agents", agent, "continuous_action") for agent in env.group_map["agents"]],
      6     distribution_class=TanhNormal,
      7     distribution_kwargs={
      8         "low": env.action_spec['agents']['agent_0']['continuous_action'].space.low,
      9         "high": env.action_spec['agents']['agent_0']['continuous_action'].space.high,
     10     },
     11     return_log_prob=False,
     12 )

File c:\Users\ky097697\Development\distributed-rl-framework\venv\Lib\site-packages\torchrl\modules\tensordict_module\actors.py:390, in ProbabilisticActor.__init__(self, module, in_keys, out_keys, spec, **kwargs)
    385 if len(out_keys) == 1 and spec is not None and not isinstance(spec, Composite):
    386     spec = Composite({out_keys[0]: spec})
    388 super().__init__(
    389     module,
--> 390     SafeProbabilisticModule(
    391         in_keys=in_keys, out_keys=out_keys, spec=spec, **kwargs
    392     ),
    393 )

File c:\Users\ky097697\Development\distributed-rl-framework\venv\Lib\site-packages\torchrl\modules\tensordict_module\probabilistic.py:132, in SafeProbabilisticModule.__init__(self, in_keys, out_keys, spec, safe, default_interaction_type, distribution_class, distribution_kwargs, return_log_prob, log_prob_key, cache_dist, n_empirical_estimate)
    130 elif spec is not None and not isinstance(spec, Composite):
    131     if len(self.out_keys) > 1:
--> 132         raise RuntimeError(
    133             f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
    134             "Consider using a Composite object or no spec at all."
    135         )
    136     spec = Composite({self.out_keys[0]: spec})
    137 elif spec is not None and isinstance(spec, Composite):

RuntimeError: got more than one out_key for the SafeModule: [('agents', 'agent_0', 'continuous_action'), ('agents', 'agent_1', 'continuous_action'), ('agents', 'agent_2', 'continuous_action'), ('agents', 'agent_3', 'continuous_action'), ('agents', 'agent_4', 'continuous_action'), ('agents', 'agent_5', 'continuous_action'), ('agents', 'agent_6', 'continuous_action'), ('agents', 'agent_7', 'continuous_action'), ('agents', 'agent_8', 'continuous_action'), ('agents', 'agent_9', 'continuous_action'), ('agents', 'agent_10', 'continuous_action'), ('agents', 'agent_11', 'continuous_action')],
but only one spec. Consider using a Composite object or no spec at all.
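As an aside, the error message itself hints at one possible workaround: with more than one out_key, the spec argument needs to be a Composite mapping each out_key to its own spec, or it can be left out entirely. A minimal, untested sketch of that idea, keeping the per-agent keys (the maintainer's recommended fix, using the Stack transform, follows below):

# Untested sketch based on the error message's suggestion: wrap the per-agent
# action specs in a Composite keyed by the same out_keys, or pass spec=None.
from torchrl.data import Composite

per_agent_spec = Composite(
    {
        ("agents", agent, "continuous_action"): env.full_action_spec[
            "agents", agent, "continuous_action"
        ]
        for agent in env.group_map["agents"]
    }
)
# ...then pass spec=per_agent_spec (or spec=None) to ProbabilisticActor.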
@kurtamohler
Collaborator

Hey @kylelevy, thank you for reporting this! You're absolutely right: the agent keys for UnityMLAgentsEnv are different from those of other multi-agent envs. To fix that, you can apply a Stack transform to the environment, like this:

>>> import torchrl
>>> env_base = torchrl.envs.UnityMLAgentsEnv(registered_name='3DBall')
>>> t = torchrl.envs.Stack(in_keys=[('group_0', f'agent_{idx}') for idx in range(12)], out_key='agents')
>>> env = torchrl.envs.TransformedEnv(env_base, t)
>>> env.reset()
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                VectorSensor_size8: Tensor(shape=torch.Size([12, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                terminated: Tensor(shape=torch.Size([12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([12, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([12]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

A while back I did make a somewhat janky proof-of-concept script that uses this to run PPO on a Unity env, but I haven't had the time to clean it up and add it as a tutorial or example script. When you get yours working, please feel free to open a PR to add it.

@kurtamohler
Collaborator

Actually, it looks like running a rollout in that env raises an error now. I'll take a look at what's wrong and fix it.

@kurtamohler
Collaborator

Ah, I forgot the Stack transform needs to invert as well. This works:

>>> import torchrl
>>> env_base = torchrl.envs.UnityMLAgentsEnv(registered_name='3DBall')
>>> in_keys = [('group_0', f'agent_{idx}') for idx in range(12)]
>>> out_key = 'agents'
>>> t = torchrl.envs.Stack(in_keys=in_keys, out_key=out_key, in_key_inv=out_key, out_keys_inv=in_keys)
>>> env = torchrl.envs.TransformedEnv(env_base, t)
>>> env.rollout(10)
TensorDict(
    fields={
        agents: TensorDict(
            fields={
                VectorSensor_size8: Tensor(shape=torch.Size([10, 12, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                continuous_action: Tensor(shape=torch.Size([10, 12, 2]), device=cpu, dtype=torch.float32, is_shared=False),
                done: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                terminated: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([10, 12]),
            device=None,
            is_shared=False),
        next: TensorDict(
            fields={
                agents: TensorDict(
                    fields={
                        VectorSensor_size8: Tensor(shape=torch.Size([10, 12, 8]), device=cpu, dtype=torch.float32, is_shared=False),
                        done: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        group_reward: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        reward: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                        terminated: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                        truncated: Tensor(shape=torch.Size([10, 12, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
                    batch_size=torch.Size([10, 12]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([10]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)
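With the stacked "agents" group, the modules from the multi-agent PPO tutorial apply more or less directly. A minimal sketch, not taken from the thread, assuming the transform also stacks the action spec under ("agents", "continuous_action") and using the observation key and shapes from the rollout above:

import torch.nn as nn
from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal

n_agents, obs_dim, act_dim = 12, 8, 2  # from the rollout shapes above

policy_net = nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs=obs_dim,
        n_agent_outputs=2 * act_dim,  # loc and scale for each action dimension
        n_agents=n_agents,
        centralised=False,
        share_params=True,
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh,
    ),
    NormalParamExtractor(),  # splits the last dim into loc and scale
)

policy_module = TensorDictModule(
    policy_net,
    in_keys=[("agents", "VectorSensor_size8")],
    out_keys=[("agents", "loc"), ("agents", "scale")],
)

policy = ProbabilisticActor(
    module=policy_module,
    spec=env.full_action_spec["agents", "continuous_action"],  # assumed stacked spec
    in_keys=[("agents", "loc"), ("agents", "scale")],
    out_keys=[("agents", "continuous_action")],
    distribution_class=TanhNormal,
    return_log_prob=True,
)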

@kurtamohler
Collaborator

kurtamohler commented Feb 14, 2025

Here's a link to my proof of concept example: https://github.com/kurtamohler/notes/blob/f2ef281f3eeff2a21a9462e58b77f1c0f1688869/pytorch/torchrl/ppo_multi_agent_notes_unity.ipynb

It took about 1.25 hours to train, although it looks like it reached maximum performance about halfway through.

@kylelevy
Author

Firstly, thank you so much for taking the time to help me out! I tried downloading the notebook you linked and running it, but I ran into an error: the VectorSensor_size8 key is not showing up like it does in the notebook.

I have seen that key before when I was experimenting with deleting agents and running from the editor.

Also, is this example for 0.6.0 or 0.7.0? Just so I can try to match my environment to yours.

[Screenshot: error raised in the notebook, with the VectorSensor_size8 key missing]

Interestingly, the rollout does work even though the VectorSensor key is not there either, just continuous_action.

[Screenshot: rollout output without the VectorSensor_size8 key]

@kurtamohler
Collaborator

kurtamohler commented Feb 15, 2025

I'm using torchrl version 0.7.0.

It looks like you should be able to change that observation key to ("agents", "observation_0"), assuming nothing else is different. I guess it's probably different because we're using different versions of mlagents. I have version 1.0.0, and I'm guessing you have 1.1.0:

$ pip list | grep mlagents
mlagents                      1.0.0
mlagents_envs                 1.0.0

@kylelevy
Author

I was able to reproduce that notebook! Thank you so much for your help. I got it working in a Linux environment but cannot reproduce it on Windows because of NumPy compatibility issues. Not sure if this is in scope, but I thought I would mention it. I would love to polish the notebook and add some explanations so that I can contribute it as an example for the wrapper! Should I open a new issue for that pull request or tack it onto this one?

@kurtamohler
Collaborator

Great! If/when you open a PR, it can just link to this issue.

@kurtamohler
Collaborator

One thing to be aware of, if you aren't already, is that the torchrl tutorials are written in Python and the notebooks are generated using sphinx-gallery. For instance, the PPO tutorial is here, and you can run that Python file directly:
https://github.com/pytorch/rl/blob/a3a1ebefefb339c05d24bfa0c6e1edfcb931c2ac/tutorials/sphinx-tutorials/coding_ppo.py

It's included in the docs here, which causes Sphinx to pick it up and build the notebook for it:

tutorials/coding_ppo

@kylelevy
Author

kylelevy commented Apr 4, 2025

Hello @kurtamohler, just following up on this issue. I was polishing up some of the example code you provided for this case and ran into a strange behavior that I wanted to share with you and get your thoughts on.

[Screenshot: mean episode reward plot spiking at the 22nd epoch and flatlining at 100]

When plotting the mean episode reward at the end of the training run, there was this weird behavior where it spikes at the 22nd epoch and then flatlines at 100 for the rest of the run. I am not sure if something is saturating or what is going on, but it is not behavior I am used to seeing. Furthermore, when comparing with your notebook, I noticed that your episode reward graph stayed between -1 and just above 0.

I am not sure why I am experiencing different behavior or why the mean episode reward behaves like that. The model at the end does behave as expected; I am just not sure why this happens.

If you have any insights, please let me know!

P.S. I will finish cleaning up the example notebook soon with a PR; I am just having some trouble getting started with Sphinx for the first time :)

@kurtamohler
Collaborator

kurtamohler commented Apr 4, 2025

I ran into that issue at one point as well. It's an issue with how the reward mean graph data is being calculated, not actually a problem with the training itself. After I first saw that spike, I tried running the model directly, and it never actually failed after running for quite a long time; the balancing heads never dropped the ball.

So the model is actually achieving maximum performance when the reward value is at the flat line on the right side of the graph. For whatever reason, sub-maximum rewards were adding up (rather than averaging) to a value greater than the actual maximum reward. I don't remember the details of what exactly was wrong (probably some typo) or what I did to fix it, but you could compare how your notebook and mine calculate the rewards to find out.
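For what it's worth, here is an illustrative sketch of the kind of mistake being described; the key names are assumptions and this is not the notebook's actual code. The logged value should be an average over finished episodes, not a sum across agents and steps:

# Illustrative only. "episode_reward" is the running return accumulated by the
# RewardSum transform; the nested key names are assumptions.
def mean_episode_reward(tensordict):
    ep_rew = tensordict["next", "agents", "episode_reward"]  # [..., n_agents, 1]
    done = tensordict["next", "agents", "done"]              # [..., n_agents, 1]
    # Buggy variant: summing over every finished entry inflates the metric
    # well past the true per-episode maximum.
    # value = ep_rew[done].sum()
    # Correct variant: average the return over episodes that actually ended.
    return ep_rew[done].mean()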
