@@ -230,57 +230,70 @@ PPO
Using PPO with multi-head action policies
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

+ .. note:: The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
+   :class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
+   When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
+   beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
+   be aggregated but rather written as leaves in the tensordict.
+
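+ As a purely illustrative sketch of what this setting changes (the two heads, distribution classes and shapes below are
+ arbitrary placeholders), a :class:`~tensordict.nn.CompositeDistribution` built over two action heads then writes one
+ log-probability entry per head rather than a single aggregated value:
+
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
+ >>> from torch import distributions as d
+ >>> set_composite_lp_aggregate(False).set()
+ >>> params = TensorDict(
+ ...     action0=TensorDict(logits=torch.randn(4, 5), batch_size=(4,)),
+ ...     action1=TensorDict(loc=torch.randn(4, 2), scale=torch.rand(4, 2), batch_size=(4,)),
+ ...     batch_size=(4,),
+ ... )
+ >>> dist = CompositeDistribution(
+ ...     params,
+ ...     distribution_map={"action0": d.Categorical, "action1": d.Normal},
+ ... )
+ >>> sample = dist.sample()
+ >>> # with aggregation disabled, the log-probs are expected to be written as separate
+ >>> # `action0_log_prob` and `action1_log_prob` leaves rather than summed into a single tensor
+ >>> lp = dist.log_prob(sample)
+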
In some cases, we have a single advantage value but more than one action is undertaken. Each action has its own
log-probability and shape. For instance, the action space may be structured as follows:

>>> action_td = TensorDict(
- ...     action0=Tensor(batch, n_agents, f0),
- ...     action1=Tensor(batch, n_agents, f1, f2),
+ ...     agents=TensorDict(
+ ...         action0=Tensor(batch, n_agents, f0),
+ ...         action1=Tensor(batch, n_agents, f1, f2),
+ ...         batch_size=torch.Size((batch, n_agents))
+ ...     ),
...     batch_size=torch.Size((batch,))
... )

where `f0`, `f1` and `f2` are some arbitrary integers.

- Note that, in TorchRL, the tensordict has the shape of the environment (if the environment is batch-locked, otherwise it
+ Note that, in TorchRL, the root tensordict has the shape of the environment (if the environment is batch-locked, otherwise it
has the shape of the number of batched environments being run). If the tensordict is sampled from the buffer, it will
also have the shape of the replay buffer `batch_size`. The `n_agents` dimension, although common to each action, does not
- in general appear in the tensordict's batch-size.
+ in general appear in the root tensordict's batch-size (although it appears in the sub-tensordict containing the
+ agent-specific data according to the :ref:`MARL API <MARL-environment-API>`).

There is a legitimate reason why this is the case: the number of agents may condition some but not all of the specs of the
environment. For example, some environments have a done state shared among all agents. A more complete tensordict
would in this case look like:

>>> action_td = TensorDict(
- ...     action0=Tensor(batch, n_agents, f0),
- ...     action1=Tensor(batch, n_agents, f1, f2),
+ ...     agents=TensorDict(
+ ...         action0=Tensor(batch, n_agents, f0),
+ ...         action1=Tensor(batch, n_agents, f1, f2),
+ ...         observation=Tensor(batch, n_agents, f3),
+ ...         batch_size=torch.Size((batch, n_agents))
+ ...     ),
...     done=Tensor(batch, 1),
- ...     observation=Tensor(batch, n_agents, f3),
...     [...]  # etc.
...     batch_size=torch.Size((batch,))
... )

Notice that `done` states and `reward` usually have a rightmost singleton dimension. See this :ref:`part of the doc <reward_done_singleton>`
to learn more about this restriction.

- The main tools to consider when building multi-head policies are: :class:`~tensordict.nn.CompositeDistribution`,
- :class:`~tensordict.nn.ProbabilisticTensorDictModule` and :class:`~tensordict.nn.ProbabilisticTensorDictSequential`.
- When dealing with these, it is recommended to call `tensordict.nn.set_composite_lp_aggregate(False).set()` at the
- beginning of the script to instruct :class:`~tensordict.nn.CompositeDistribution` that log-probabilities should not
- be aggregated but rather written as leaves in the tensordict.
-
The log-probability of our actions given their respective distributions may look like either of the following:

>>> action_td = TensorDict(
- ...     action0_log_prob=Tensor(batch, n_agents),
- ...     action1_log_prob=Tensor(batch, n_agents, f1),
+ ...     agents=TensorDict(
+ ...         action0_log_prob=Tensor(batch, n_agents),
+ ...         action1_log_prob=Tensor(batch, n_agents, f1),
+ ...         batch_size=torch.Size((batch, n_agents))
+ ...     ),
...     batch_size=torch.Size((batch,))
... )

or

>>> action_td = TensorDict(
- ...     action0_log_prob=Tensor(batch, n_agents),
- ...     action1_log_prob=Tensor(batch, n_agents),
+ ...     agents=TensorDict(
+ ...         action0_log_prob=Tensor(batch, n_agents),
+ ...         action1_log_prob=Tensor(batch, n_agents),
+ ...         batch_size=torch.Size((batch, n_agents))
+ ...     ),
...     batch_size=torch.Size((batch,))
... )
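+
+ For illustration only, here is a sketch of how such a per-leaf log-probability layout can be produced with
+ :class:`~tensordict.nn.CompositeDistribution` (the Categorical/Normal heads and the sizes below are arbitrary
+ placeholders; in a complete policy, the distribution is typically built by a
+ :class:`~tensordict.nn.ProbabilisticTensorDictModule` with ``distribution_class=CompositeDistribution``, possibly
+ wrapped in a :class:`~tensordict.nn.ProbabilisticTensorDictSequential`):
+
+ >>> import torch
+ >>> from tensordict import TensorDict
+ >>> from tensordict.nn import CompositeDistribution, set_composite_lp_aggregate
+ >>> from torch import distributions as d
+ >>> set_composite_lp_aggregate(False).set()
+ >>> batch, n_agents, f0 = 8, 3, 5
+ >>> params = TensorDict(
+ ...     agents=TensorDict(
+ ...         action0=TensorDict(logits=torch.randn(batch, n_agents, f0), batch_size=(batch, n_agents)),
+ ...         action1=TensorDict(
+ ...             loc=torch.randn(batch, n_agents),
+ ...             scale=torch.rand(batch, n_agents),
+ ...             batch_size=(batch, n_agents),
+ ...         ),
+ ...         batch_size=(batch, n_agents),
+ ...     ),
+ ...     batch_size=(batch,),
+ ... )
+ >>> dist = CompositeDistribution(
+ ...     params,
+ ...     distribution_map={
+ ...         ("agents", "action0"): d.Categorical,
+ ...         ("agents", "action1"): d.Normal,
+ ...     },
+ ... )
+ >>> action_td = dist.sample()
+ >>> # the per-action log-prob leaves are expected to be written next to the actions,
+ >>> # e.g. at ("agents", "action0_log_prob") and ("agents", "action1_log_prob")
+ >>> action_td = dist.log_prob(action_td)
+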
@@ -336,7 +349,7 @@ Dreamer
DreamerValueLoss

Multi-agent objectives
- -----------------------
+ ----------------------

.. currentmodule:: torchrl.objectives.multiagent