Trying to get UnityMLAgentsEnv example running #2697

Unanswered
kylelevy asked this question in Q&A

Hello there,

I am new to TorchRL and am trying to use it to train a PPO agent in Unity ML-Agents. Right now I am just trying to get the head_balance example scene running, but I have been having some difficulty with the env because it does not line up with the setup used in the other tutorials.

The UnityMLAgentsEnv is working and returns an env with the 12 agents in the head_balance scene. As the UnityMLAgentsEnv docs suggest in their example, each agent sits inside one group in the TensorDict and has its own fields such as continuous_action, and the rollout works.

The problem, however, is that the keys do not look like those in either the Multiagent PPO tutorial or the Multiagent DDPG tutorial, and I cannot find an example that covers this format. In both tutorials the expected keys for the other environments are ('agents', 'action'), ('agents', 'observation'), etc., since all agents are homogeneous and stacked into one tensor right from the environment. The ML-Agents head_balance example is not stacked, so I am not sure how to correctly wire the individual agent keys into the policy and critic modules.
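To illustrate what I mean, here is a rough, untested sketch of the kind of reshaping I imagine is needed, stacking the per-agent leaves into a single tutorial-style ('agents', 'observation') tensor (the helper name is mine, and the observation_0 key is just what I see in my rollout):

import torch
from tensordict import TensorDict

def stack_group_obs(td: TensorDict, group: str = "agents") -> torch.Tensor:
    # Hypothetical helper: collect the ("agents", "agent_i", "observation_0") leaves
    # and stack them into a [..., n_agents, obs_dim] tensor like the tutorials expect.
    agent_names = sorted(td[group].keys(), key=lambda name: int(name.split("_")[-1]))
    return torch.stack(
        [td[group, name, "observation_0"] for name in agent_names], dim=-2
    )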

I have been working on getting this example up and running for a while now and am stuck on how to correctly interface this style of environment with the different modules. Could I please get some advice or direction on how to go about this?

P.S. If I can get head_balance working with TorchRL and the UnityMLAgentsEnv, I would be more than happy to open a pull request and contribute it so others can avoid the same headaches.

Setup:

  • python3.12
  • torchrl==0.6.0
  • tensordict==0.6.1
  • mlagents==0.28.0
  • mlagents-env==0.28.0

Code:

import multiprocessing
import torch
from tensordict.nn import TensorDictModule, TensorDictSequential
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn

from torchrl.collectors import SyncDataCollector
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage
from torchrl.envs import (
    Compose,
    TransformedEnv,
    RewardSum
)
from torchrl.envs import UnityMLAgentsEnv, MarlGroupMapType
from torchrl.envs.utils import check_env_specs
from torchrl.modules import MultiAgentMLP, ProbabilisticActor, TanhNormal, AdditiveGaussianModule
from tqdm import tqdm

# Devices
is_fork = multiprocessing.get_start_method() == "fork"
device = (
    torch.device(0)
    if torch.cuda.is_available() and not is_fork
    else torch.device("cpu")
)

# Sampling
frames_per_batch = 6_000  # Number of team frames collected per training iteration
n_iters = 10  # Number of sampling and training iterations
total_frames = frames_per_batch * n_iters

# Training
num_epochs = 30  # Number of optimization steps per training iteration
minibatch_size = 400  # Size of the mini-batches in each optimization step
lr = 3e-4  # Learning rate
max_grad_norm = 1.0  # Maximum norm for the gradients

# PPO
clip_epsilon = 0.2  # clip value for PPO loss
gamma = 0.99  # discount factor
lmbda = 0.9  # lambda for generalised advantage estimation
entropy_eps = 1e-4  # coefficient of the entropy term in the PPO loss

# Registered 3DBall (head_balance) scene; put all agents into a single "agents" group
base_env = UnityMLAgentsEnv(
    registered_name="3DBall",
    device=device,
    group_map=MarlGroupMapType.ALL_IN_ONE_GROUP,
)

env = TransformedEnv(
    base_env,
    RewardSum(
        in_keys=[key for key in base_env.reward_keys if key[2] == "reward"], # exclude group reward
        reset_keys=base_env.reset_keys
    )
)

check_env_specs(base_env)

n_rollout_steps = 5
rollout = env.rollout(n_rollout_steps)

share_parameters_policy = True

# Policy network: parameter-shared MultiAgentMLP; NormalParamExtractor splits its output into loc and scale
policy_net = nn.Sequential(
    MultiAgentMLP(
        n_agent_inputs=env.observation_spec['agents']['agent_0']['observation_0'].shape[-1],
        n_agent_outputs=env.action_spec['agents']['agent_0']['continuous_action'].shape[-1],
        n_agents=len(env.group_map['agents']),
        centralised=False,
        share_params=share_parameters_policy,
        device=device,
        depth=2,
        num_cells=256,
        activation_class=nn.Tanh
    ),
    NormalParamExtractor(),
)

# Map each agent's observation to a per-agent "action_param" entry through the shared network
policy_module = TensorDictModule(
    policy_net,
    in_keys=[("agents", agent, "observation_0") for agent in env.group_map["agents"]],
    out_keys=[("agents", agent, "action_param") for agent in env.group_map["agents"]],
)

# Wrap the policy so it samples a TanhNormal action for each agent
policy = ProbabilisticActor(
    module=policy_module,
    spec=env.full_action_spec["agents", "agent_0", "continuous_action"],
    in_keys=[("agents", agent, "action_param") for agent in env.group_map["agents"]],
    out_keys=[("agents", agent, "continuous_action") for agent in env.group_map["agents"]],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": env.action_spec['agents']['agent_0']['continuous_action'].space.low,
        "high": env.action_spec['agents']['agent_0']['continuous_action'].space.high,
    },
    return_log_prob=False,
)

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[12], line 1
----> 1 policy = ProbabilisticActor(
      2     module=policy_module,
      3     spec=env.full_action_spec["agents", "agent_0", "continuous_action"],
      4     in_keys=[("agents", agent, "action_param") for agent in env.group_map["agents"]],
      5     out_keys=[("agents", agent, "continuous_action") for agent in env.group_map["agents"]],
      6     distribution_class=TanhNormal,
      7     distribution_kwargs={
      8         "low": env.action_spec['agents']['agent_0']['continuous_action'].space.low,
      9         "high": env.action_spec['agents']['agent_0']['continuous_action'].space.high,
     10     },
     11     return_log_prob=False,
     12 )

File c:\Users\ky097697\Development\distributed-rl-framework\venv\Lib\site-packages\torchrl\modules\tensordict_module\actors.py:390, in ProbabilisticActor.__init__(self, module, in_keys, out_keys, spec, **kwargs)
    385 if len(out_keys) == 1 and spec is not None and not isinstance(spec, Composite):
    386     spec = Composite({out_keys[0]: spec})
    388 super().__init__(
    389     module,
--> 390     SafeProbabilisticModule(
    391         in_keys=in_keys, out_keys=out_keys, spec=spec, **kwargs
    392     ),
    393 )

File c:\Users\ky097697\Development\distributed-rl-framework\venv\Lib\site-packages\torchrl\modules\tensordict_module\probabilistic.py:132, in SafeProbabilisticModule.__init__(self, in_keys, out_keys, spec, safe, default_interaction_type, distribution_class, distribution_kwargs, return_log_prob, log_prob_key, cache_dist, n_empirical_estimate)
    130 elif spec is not None and not isinstance(spec, Composite):
    131     if len(self.out_keys) > 1:
--> 132         raise RuntimeError(
    133             f"got more than one out_key for the SafeModule: {self.out_keys},\nbut only one spec. "
    134             "Consider using a Composite object or no spec at all."
    135         )
    136     spec = Composite({self.out_keys[0]: spec})
    137 elif spec is not None and isinstance(spec, Composite):

RuntimeError: got more than one out_key for the SafeModule: [('agents', 'agent_0', 'continuous_action'), ('agents', 'agent_1', 'continuous_action'), ('agents', 'agent_2', 'continuous_action'), ('agents', 'agent_3', 'continuous_action'), ('agents', 'agent_4', 'continuous_action'), ('agents', 'agent_5', 'continuous_action'), ('agents', 'agent_6', 'continuous_action'), ('agents', 'agent_7', 'continuous_action'), ('agents', 'agent_8', 'continuous_action'), ('agents', 'agent_9', 'continuous_action'), ('agents', 'agent_10', 'continuous_action'), ('agents', 'agent_11', 'continuous_action')],
but only one spec. Consider using a Composite object or no spec at all.
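From the error message it sounds like ProbabilisticActor wants either a single Composite spec that covers all of the out_keys, or no spec at all. For what it's worth, this is the direction I have been considering (untested sketch; whether a Composite built this way is the right fix, and whether the rest of the actor then works with the per-agent keys, is exactly what I am unsure about):

from torchrl.data import Composite

# Untested: one Composite spec keyed by every per-agent out_key, instead of a single leaf spec
per_agent_spec = env.full_action_spec["agents", "agent_0", "continuous_action"]
action_spec = Composite(
    {
        ("agents", agent, "continuous_action"): per_agent_spec.clone()
        for agent in env.group_map["agents"]
    }
)

policy = ProbabilisticActor(
    module=policy_module,
    spec=action_spec,  # or spec=None, as the error message also suggests
    in_keys=[("agents", agent, "action_param") for agent in env.group_map["agents"]],
    out_keys=[("agents", agent, "continuous_action") for agent in env.group_map["agents"]],
    distribution_class=TanhNormal,
    distribution_kwargs={
        "low": per_agent_spec.space.low,
        "high": per_agent_spec.space.high,
    },
    return_log_prob=False,
)

Even if this clears the spec error, I suspect the multiple action_param in_keys still will not map cleanly onto TanhNormal's loc/scale, which is why I am asking for the recommended pattern here.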