Commit fe075e4 (parent: d009835)

[Feature] Make PPO compatible with composite actions and log-probs

ghstack-source-id: f465f2017843904a510aa06768ced457df987e94
Pull Request resolved: #2665

2 files changed: +95, -31 lines
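This commit lets the PPO losses consume actions and log-probabilities stored under several (possibly nested) keys instead of a single "action"/"sample_log_prob" pair. As a rough illustration of the kind of data layout the feature targets (the key names below are made up for the example, not taken from the commit):

    import torch
    from tensordict import TensorDict

    batch = 8
    rollout = TensorDict(
        {
            "observation": torch.randn(batch, 4),
            # two action heads, each with its own log-probability entry
            ("agents", "action_0"): torch.randn(batch, 2),
            ("agents", "action_1"): torch.randint(0, 3, (batch,)),
            ("agents", "action_0_log_prob"): torch.randn(batch),
            ("agents", "action_1_log_prob"): torch.randn(batch),
        },
        batch_size=[batch],
    )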

torchrl/objectives/ppo.py (+70, -29)
@@ -8,14 +8,15 @@

 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Tuple
+from typing import List, Tuple

 import torch
 from tensordict import (
     is_tensor_collection,
     TensorDict,
     TensorDictBase,
     TensorDictParams,
+    unravel_key,
 )
 from tensordict.nn import (
     CompositeDistribution,
@@ -33,6 +34,8 @@
     _cache_values,
     _clip_value_loss,
     _GAMMA_LMBDA_DEPREC_ERROR,
+    _maybe_add_or_extend_key,
+    _maybe_get_or_select,
     _reduce,
     _sum_td_features,
     default_value_kwargs,
@@ -67,7 +70,10 @@ class PPOLoss(LossModule):

     Args:
         actor_network (ProbabilisticTensorDictSequential): policy operator.
-        critic_network (ValueOperator): value operator.
+            Typically a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations
+            as input and outputting an action (or actions) as well as its log-probability value.
+        critic_network (ValueOperator): value operator. The critic will usually take the observations as input
+            and return a scalar value (``state_value`` by default) in the output keys.

     Keyword Args:
         entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
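To make the updated docstring concrete, here is a minimal single-head sketch of the actor/critic contract it describes; module sizes and key names are illustrative and the default keys ("action", "sample_log_prob", "state_value") are assumed:

    import torch.nn as nn
    from tensordict.nn import TensorDictModule
    from torchrl.modules import NormalParamExtractor, ProbabilisticActor, TanhNormal, ValueOperator
    from torchrl.objectives import PPOLoss

    actor_backbone = TensorDictModule(
        nn.Sequential(nn.Linear(4, 8), NormalParamExtractor()),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    )
    actor = ProbabilisticActor(
        actor_backbone,
        in_keys=["loc", "scale"],
        distribution_class=TanhNormal,
        return_log_prob=True,  # writes the log-probability next to the action
    )
    critic = ValueOperator(nn.Linear(4, 1), in_keys=["observation"])  # writes "state_value"
    loss = PPOLoss(actor, critic)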
@@ -267,28 +273,28 @@ class _AcceptedKeys:
             Will be used for the underlying value estimator Defaults to ``"value_target"``.
         value (NestedKey): The input tensordict key where the state value is expected.
             Will be used for the underlying value estimator. Defaults to ``"state_value"``.
-        sample_log_prob (NestedKey): The input tensordict key where the
+        sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the
             sample log probability is expected. Defaults to ``"sample_log_prob"``.
-        action (NestedKey): The input tensordict key where the action is expected.
+        action (NestedKey or list of nested keys): The input tensordict key where the action is expected.
             Defaults to ``"action"``.
-        reward (NestedKey): The input tensordict key where the reward is expected.
+        reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected.
             Will be used for the underlying value estimator. Defaults to ``"reward"``.
-        done (NestedKey): The key in the input TensorDict that indicates
+        done (NestedKey or list of nested keys): The key in the input TensorDict that indicates
             whether a trajectory is done. Will be used for the underlying value estimator.
             Defaults to ``"done"``.
-        terminated (NestedKey): The key in the input TensorDict that indicates
+        terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates
             whether a trajectory is terminated. Will be used for the underlying value estimator.
             Defaults to ``"terminated"``.
         """

         advantage: NestedKey = "advantage"
         value_target: NestedKey = "value_target"
         value: NestedKey = "state_value"
-        sample_log_prob: NestedKey = "sample_log_prob"
-        action: NestedKey = "action"
-        reward: NestedKey = "reward"
-        done: NestedKey = "done"
-        terminated: NestedKey = "terminated"
+        sample_log_prob: NestedKey | List[NestedKey] = "sample_log_prob"
+        action: NestedKey | List[NestedKey] = "action"
+        reward: NestedKey | List[NestedKey] = "reward"
+        done: NestedKey | List[NestedKey] = "done"
+        terminated: NestedKey | List[NestedKey] = "terminated"

     default_keys = _AcceptedKeys()
     default_value_estimator = ValueEstimators.GAE
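With the accepted keys now typed as NestedKey | List[NestedKey], a loss instance can be pointed at several action/log-prob entries at once. Continuing the loss/actor/critic sketch above and reusing the hypothetical keys from the earlier layout, via the standard LossModule.set_keys API:

    loss.set_keys(
        action=[("agents", "action_0"), ("agents", "action_1")],
        sample_log_prob=[
            ("agents", "action_0_log_prob"),
            ("agents", "action_1_log_prob"),
        ],
    )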
@@ -369,7 +375,7 @@ def __init__(

         try:
             device = next(self.parameters()).device
-        except AttributeError:
+        except (AttributeError, StopIteration):
             device = torch.device("cpu")

         self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
@@ -409,15 +415,36 @@ def functional(self):

     def _set_in_keys(self):
         keys = [
-            self.tensor_keys.action,
-            self.tensor_keys.sample_log_prob,
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
-            ("next", self.tensor_keys.terminated),
             *self.actor_network.in_keys,
             *[("next", key) for key in self.actor_network.in_keys],
             *self.critic_network.in_keys,
         ]
+
+        if isinstance(self.tensor_keys.action, NestedKey):
+            keys.append(self.tensor_keys.action)
+        else:
+            keys.extend(self.tensor_keys.action)
+
+        if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+            keys.append(self.tensor_keys.sample_log_prob)
+        else:
+            keys.extend(self.tensor_keys.sample_log_prob)
+
+        if isinstance(self.tensor_keys.reward, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.reward)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.reward])
+
+        if isinstance(self.tensor_keys.done, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.done)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.done])
+
+        if isinstance(self.tensor_keys.terminated, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.terminated)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.terminated])
+
         self._in_keys = list(set(keys))

     @property
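The unravel_key calls above flatten a ("next", key) pair even when key is itself a nested tuple, so composite reward/done/terminated keys end up as flat tuples in in_keys. For instance:

    from tensordict import unravel_key

    assert unravel_key(("next", "reward")) == ("next", "reward")
    assert unravel_key(("next", ("agents", "reward"))) == ("next", "agents", "reward")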
@@ -472,25 +499,38 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
             if is_tensor_collection(entropy):
                 entropy = _sum_td_features(entropy)
         except NotImplementedError:
-            x = dist.rsample((self.samples_mc_entropy,))
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
             log_prob = dist.log_prob(x)
-            if is_tensor_collection(log_prob):
+
+            if is_tensor_collection(log_prob) and isinstance(
+                self.tensor_keys.sample_log_prob, NestedKey
+            ):
                 log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+
             entropy = -log_prob.mean(0)
         return entropy.unsqueeze(-1)

     def _log_weight(
         self, tensordict: TensorDictBase
     ) -> Tuple[torch.Tensor, d.Distribution]:
+
         # current log_prob of actions
-        action = tensordict.get(self.tensor_keys.action)
+        action = _maybe_get_or_select(tensordict, self.tensor_keys.action)

         with self.actor_network_params.to_module(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)

-        prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+        prev_log_prob = _maybe_get_or_select(
+            tensordict, self.tensor_keys.sample_log_prob
+        )
+
         if prev_log_prob.requires_grad:
             raise RuntimeError(
                 f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
@@ -513,8 +553,8 @@ def _log_weight(
         else:
             is_composite = False
             kwargs = {}
-        log_prob = dist.log_prob(tensordict, **kwargs)
-        if is_composite and not isinstance(prev_log_prob, TensorDict):
+        log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
+        if is_composite and not is_tensor_collection(prev_log_prob):
             log_prob = _sum_td_features(log_prob)
         log_prob.view_as(prev_log_prob)

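In the composite case, dist.log_prob returns a tensordict of per-key log-probabilities; when the stored sample_log_prob is a single tensor, the per-key terms are reduced to one tensor (via _sum_td_features) before the importance weight is formed. A rough sketch of that reduction on a hand-built tensordict, not the library call itself:

    import torch
    from tensordict import TensorDict

    log_prob_td = TensorDict(
        {
            ("agents", "action_0_log_prob"): torch.randn(8),
            ("agents", "action_1_log_prob"): torch.randn(8),
        },
        batch_size=[8],
    )
    # sum the leaf log-probs into a single per-sample tensor
    summed = sum(log_prob_td.values(include_nested=True, leaves_only=True))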
@@ -1088,15 +1128,16 @@ def __init__(

     def _set_in_keys(self):
         keys = [
-            self.tensor_keys.action,
-            self.tensor_keys.sample_log_prob,
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
-            ("next", self.tensor_keys.terminated),
             *self.actor_network.in_keys,
             *[("next", key) for key in self.actor_network.in_keys],
             *self.critic_network.in_keys,
         ]
+        _maybe_add_or_extend_key(keys, self.tensor_keys.action)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")
+
         # Get the parameter keys from the actor dist
         actor_dist_module = None
         for module in self.actor_network.modules():

torchrl/objectives/utils.py (+25, -2)
@@ -8,10 +8,10 @@
 import re
 import warnings
 from enum import Enum
-from typing import Iterable, Optional, Union
+from typing import Iterable, List, Optional, Union

 import torch
-from tensordict import TensorDict, TensorDictBase
+from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
 from tensordict.nn import TensorDictModule
 from torch import nn, Tensor
 from torch.nn import functional as F
@@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize
 def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
     # Sum all features and return a tensor
     return data.sum(dim="feature", reduce=True)
+
+
+def _maybe_get_or_select(td, key_or_keys):
+    if isinstance(key_or_keys, (str, tuple)):
+        return td.get(key_or_keys)
+    return td.select(*key_or_keys)
+
+
+def _maybe_add_or_extend_key(
+    tensor_keys: List[NestedKey],
+    key_or_list_of_keys: NestedKey | List[NestedKey],
+    prefix: NestedKey = None,
+):
+    if prefix is not None:
+        if isinstance(key_or_list_of_keys, NestedKey):
+            tensor_keys.append(unravel_key((prefix, key_or_list_of_keys)))
+        else:
+            tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys])
+        return
+    if isinstance(key_or_list_of_keys, NestedKey):
+        tensor_keys.append(key_or_list_of_keys)
+    else:
+        tensor_keys.extend(key_or_list_of_keys)
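A quick usage sketch of the two helpers added above, run against this commit; the key names are illustrative:

    import torch
    from tensordict import TensorDict
    from torchrl.objectives.utils import _maybe_add_or_extend_key, _maybe_get_or_select

    keys = []
    _maybe_add_or_extend_key(keys, "action")
    _maybe_add_or_extend_key(keys, [("agents", "reward"), "reward"], "next")
    # keys == ["action", ("next", "agents", "reward"), ("next", "reward")]

    td = TensorDict({"action": torch.randn(4, 2)}, batch_size=[4])
    single = _maybe_get_or_select(td, "action")    # plain get -> a tensor
    subset = _maybe_get_or_select(td, ["action"])  # select -> a sub-TensorDict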
