Commit ad7d2a1

[Minor] Fix doc and MARL tests
ghstack-source-id: 9308be3ebc7fac30b5bde321792eb97069d55996
Pull Request resolved: #2759
1 parent cb37521 commit ad7d2a1

File tree: 3 files changed, +68 −39 lines changed


docs/source/reference/objectives.rst (+31 −18)
@@ -230,57 +230,70 @@ PPO
 Using PPO with multi-head action policies
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+.. note:: The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
+    :class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
+    When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
+    beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
+    be aggregated but rather written as leaves in the tensordict.
+
 In some cases, we have a single advantage value but more than one action undertaken. Each action has its own
 log-probability, and shape. For instance, it can be that the action space is structured as follows:

 >>> action_td = TensorDict(
-...     action0=Tensor(batch, n_agents, f0),
-...     action1=Tensor(batch, n_agents, f1, f2),
+...     agents=TensorDict(
+...         action0=Tensor(batch, n_agents, f0),
+...         action1=Tensor(batch, n_agents, f1, f2),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
 ...     batch_size=torch.Size((batch,))
 ... )

 where `f0`, `f1` and `f2` are some arbitrary integers.

-Note that, in TorchRL, the tensordict has the shape of the environment (if the environment is batch-locked, otherwise it
+Note that, in TorchRL, the root tensordict has the shape of the environment (if the environment is batch-locked, otherwise it
 has the shape of the number of batched environments being run). If the tensordict is sampled from the buffer, it will
 also have the shape of the replay buffer `batch_size`. The `n_agent` dimension, although common to each action, does not
-in general appear in the tensordict's batch-size.
+in general appear in the root tensordict's batch-size (although it appears in the sub-tensordict containing the
+agent-specific data according to the :ref:`MARL API <MARL-environment-API>`).

 There is a legitimate reason why this is the case: the number of agent may condition some but not all the specs of the
 environment. For example, some environments have a shared done state among all agents. A more complete tensordict
 would in this case look like

 >>> action_td = TensorDict(
-...     action0=Tensor(batch, n_agents, f0),
-...     action1=Tensor(batch, n_agents, f1, f2),
+...     agents=TensorDict(
+...         action0=Tensor(batch, n_agents, f0),
+...         action1=Tensor(batch, n_agents, f1, f2),
+...         observation=Tensor(batch, n_agents, f3),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
 ...     done=Tensor(batch, 1),
-...     observation=Tensor(batch, n_agents, f3),
 ...     [...] # etc
 ...     batch_size=torch.Size((batch,))
 ... )

 Notice that `done` states and `reward` are usually flanked by a rightmost singleton dimension. See this :ref:`part of the doc <reward_done_singleton>`
 to learn more about this restriction.

-The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
-:class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
-When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
-beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
-be aggregated but rather written as leaves in the tensordict.
-
 The log-probability of our actions given their respective distributions may look like anything like

 >>> action_td = TensorDict(
-...     action0_log_prob=Tensor(batch, n_agents),
-...     action1_log_prob=Tensor(batch, n_agents, f1),
+...     agents=TensorDict(
+...         action0_log_prob=Tensor(batch, n_agents),
+...         action1_log_prob=Tensor(batch, n_agents, f1),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
 ...     batch_size=torch.Size((batch,))
 ... )

 or

 >>> action_td = TensorDict(
-...     action0_log_prob=Tensor(batch, n_agents),
-...     action1_log_prob=Tensor(batch, n_agents),
+...     agents=TensorDict(
+...         action0_log_prob=Tensor(batch, n_agents),
+...         action1_log_prob=Tensor(batch, n_agents),
+...         batch_size=torch.Size((batch, n_agents))
+...     ),
 ...     batch_size=torch.Size((batch,))
 ... )

@@ -336,7 +349,7 @@ Dreamer
     DreamerValueLoss

 Multi-agent objectives
------------------------
+----------------------

 .. currentmodule:: torchrl.objectives.multiagent
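
To make the nested layout concrete, here is a short, self-contained sketch that is not part of the commit: it assumes illustrative sizes (batch=4, n_agents=3) and borrows the Dirichlet/Categorical heads from the test further down, to show where the per-head log-probabilities land once `set_composite_lp_aggregate(False)` is in effect.

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
    CompositeDistribution,
    InteractionType,
    ProbabilisticTensorDictModule,
    set_composite_lp_aggregate,
    set_interaction_type,
)

# As the note recommends: keep per-head log-probs as separate leaves.
set_composite_lp_aggregate(False).set()

batch, n_agents = 4, 3  # assumed sizes, for illustration only

# Distribution parameters nested under "agents", mirroring the MARL layout above.
params = TensorDict(
    agents=TensorDict(
        dirich=TensorDict(concentration=torch.ones(batch, n_agents, 10, 2)),
        categ=TensorDict(logits=torch.zeros(batch, n_agents, 7)),
        batch_size=(batch, n_agents),
    ),
    batch_size=(batch,),
)

policy_head = ProbabilisticTensorDictModule(
    in_keys=["params"],
    out_keys=[("agents", "dirich"), ("agents", "categ")],
    distribution_class=functools.partial(
        CompositeDistribution,
        distribution_map={
            ("agents", "dirich"): lambda concentration: torch.distributions.Independent(
                torch.distributions.Dirichlet(concentration), 1
            ),
            ("agents", "categ"): torch.distributions.Categorical,
        },
    ),
    return_log_prob=True,
)

with set_interaction_type(InteractionType.RANDOM):  # sample instead of taking a mode
    out = policy_head(TensorDict(params=params, batch_size=(batch,)))

# Log-probabilities are written as leaves inside the "agents" sub-tensordict.
assert out["agents", "dirich_log_prob"].shape == (batch, n_agents)
assert out["agents", "categ_log_prob"].shape == (batch, n_agents)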

test/test_cost.py (+36 −20)
@@ -209,45 +209,54 @@ def __init__(self):
         self.obs_feat = obs_feat = (5,)

         self.full_observation_spec = Composite(
-            observation=Unbounded(batch + n_agents + obs_feat),
-            batch_size=batch,
+            agents=Composite(
+                observation=Unbounded(batch + n_agents + obs_feat),
+                shape=batch + n_agents,
+            ),
+            shape=batch,
         )
         self.full_done_spec = Composite(
             done=Unbounded(batch + (1,), dtype=torch.bool),
             terminated=Unbounded(batch + (1,), dtype=torch.bool),
             truncated=Unbounded(batch + (1,), dtype=torch.bool),
-            batch_size=batch,
+            shape=batch,
         )

-        self.act_feat_dirich = act_feat_dirich = (
-            10,
-            2,
-        )
+        self.act_feat_dirich = act_feat_dirich = (10, 2)
         self.act_feat_categ = act_feat_categ = (7,)
         self.full_action_spec = Composite(
-            dirich=Unbounded(batch + n_agents + act_feat_dirich),
-            categ=Unbounded(batch + n_agents + act_feat_categ),
-            batch_size=batch,
+            agents=Composite(
+                dirich=Unbounded(batch + n_agents + act_feat_dirich),
+                categ=Unbounded(batch + n_agents + act_feat_categ),
+                shape=batch + n_agents,
+            ),
+            shape=batch,
         )

         self.full_reward_spec = Composite(
-            reward=Unbounded(batch + n_agents + (1,)), batch_size=batch
+            agents=Composite(
+                reward=Unbounded(batch + n_agents + (1,)), shape=batch + n_agents
+            ),
+            shape=batch,
         )

     @classmethod
     def make_composite_dist(cls):
         dist_cstr = functools.partial(
             CompositeDistribution,
             distribution_map={
-                "dirich": lambda concentration: torch.distributions.Independent(
+                (
+                    "agents",
+                    "dirich",
+                ): lambda concentration: torch.distributions.Independent(
                     torch.distributions.Dirichlet(concentration), 1
                 ),
-                "categ": torch.distributions.Categorical,
+                ("agents", "categ"): torch.distributions.Categorical,
             },
         )
         return ProbabilisticTensorDictModule(
             in_keys=["params"],
-            out_keys=["dirich", "categ"],
+            out_keys=[("agents", "dirich"), ("agents", "categ")],
             distribution_class=dist_cstr,
             return_log_prob=True,
         )

@@ -9309,8 +9318,13 @@ def test_ppo_marl_aggregate(self):

         def primer(td):
             params = TensorDict(
-                dirich=TensorDict(concentration=env.action_spec["dirich"].one()),
-                categ=TensorDict(logits=env.action_spec["categ"].one()),
+                agents=TensorDict(
+                    dirich=TensorDict(
+                        concentration=env.action_spec["agents", "dirich"].one()
+                    ),
+                    categ=TensorDict(logits=env.action_spec["agents", "categ"].one()),
+                    batch_size=env.action_spec["agents"].shape,
+                ),
                 batch_size=td.batch_size,
             )
             td.set("params", params)

@@ -9323,11 +9337,13 @@ def primer(td):
         )
         output = policy(env.fake_tensordict())
         assert output.shape == env.batch_size
-        assert output["dirich_log_prob"].shape == env.batch_size + env.n_agents
-        assert output["categ_log_prob"].shape == env.batch_size + env.n_agents
+        assert (
+            output["agents", "dirich_log_prob"].shape == env.batch_size + env.n_agents
+        )
+        assert output["agents", "categ_log_prob"].shape == env.batch_size + env.n_agents

-        output["advantage"] = output["next", "reward"].clone()
-        output["value_target"] = output["next", "reward"].clone()
+        output["advantage"] = output["next", "agents", "reward"].clone()
+        output["value_target"] = output["next", "agents", "reward"].clone()
         critic = TensorDictModule(
             lambda obs: obs.new_zeros((*obs.shape[:-1], 1)),
             in_keys=list(env.full_observation_spec.keys(True, True)),
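
The spec changes above move each per-agent entry under an inner `agents` Composite whose shape carries the agent dimension, while the root spec keeps only the environment batch size. A minimal sketch of that layout follows; the sizes are assumed for illustration, since the hunk does not show the test's actual `batch` and `n_agents` values.

import torch
from torchrl.data import Composite, Unbounded

batch, n_agents, obs_feat = (2,), (3,), (5,)

full_observation_spec = Composite(
    agents=Composite(
        observation=Unbounded(batch + n_agents + obs_feat),
        shape=batch + n_agents,  # inner spec: env batch + agent dimension
    ),
    shape=batch,  # root spec: env batch only
)

# Sampling reproduces the nesting described in the doc change above.
sample = full_observation_spec.rand()
assert sample.batch_size == torch.Size(batch)
assert sample["agents"].batch_size == torch.Size(batch + n_agents)
assert sample["agents", "observation"].shape == torch.Size(batch + n_agents + obs_feat)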

torchrl/envs/common.py (+1 −1)
@@ -3574,7 +3574,7 @@ def fake_tensordict(self) -> TensorDictBase:
         observation_spec = self.observation_spec
         action_spec = self.input_spec["full_action_spec"]
         # instantiates reward_spec if needed
-        _ = self.reward_spec
+        _ = self.full_reward_spec
         reward_spec = self.output_spec["full_reward_spec"]
         full_done_spec = self.output_spec["full_done_spec"]
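
As context for the one-line fix (added here, not from the commit): `full_reward_spec` is the Composite that keeps the complete, possibly nested, reward layout such as ("agents", "reward"), so accessing it is enough to instantiate the entries that the next line reads back from `output_spec["full_reward_spec"]`. A small sketch with assumed sizes:

import torch
from torchrl.data import Composite, Unbounded

batch, n_agents = (2,), (3,)

full_reward_spec = Composite(
    agents=Composite(
        reward=Unbounded(batch + n_agents + (1,)),  # rightmost singleton dim
        shape=batch + n_agents,
    ),
    shape=batch,
)

fake_reward = full_reward_spec.zero()  # nested TensorDict filled with zeros
assert ("agents", "reward") in fake_reward.keys(True, True)
assert fake_reward["agents", "reward"].shape == torch.Size(batch + n_agents + (1,))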
