[Feature] Make PPO compatible with composite actions and log-probs #2665

Merged · 12 commits · Jan 16, 2025
26 changes: 13 additions & 13 deletions .github/unittest/linux_sota/scripts/test_sota.py
@@ -188,19 +188,6 @@
ppo.collector.frames_per_batch=16 \
logger.mode=offline \
logger.backend=
""",
"dreamer": """python sota-implementations/dreamer/dreamer.py \
collector.total_frames=600 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.video=False \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
""",
"ddpg-single": """python sota-implementations/ddpg/ddpg.py \
collector.total_frames=48 \
@@ -289,6 +276,19 @@
logger.backend=
""",
"bandits": """python sota-implementations/bandits/dqn.py --n_steps=100
""",
"dreamer": """python sota-implementations/dreamer/dreamer.py \
collector.total_frames=600 \
collector.init_random_frames=10 \
collector.frames_per_batch=200 \
env.n_parallel_envs=1 \
optimization.optim_steps_per_batch=1 \
logger.video=False \
logger.backend=csv \
replay_buffer.buffer_size=120 \
replay_buffer.batch_size=24 \
replay_buffer.batch_length=12 \
networks.rssm_hidden_dim=17
""",
}

6 changes: 6 additions & 0 deletions examples/agents/composite_actor.py
@@ -50,3 +50,9 @@ def forward(self, x):
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))


# TODO:
# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action")
# 2. Should multi-head policies require action_key to be a list of keys? (presumably yes)
# 3. Using maps in the Actor
190 changes: 190 additions & 0 deletions examples/agents/composite_ppo.py
@@ -0,0 +1,190 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Multi-head Agent and PPO Loss
=============================
This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.

Step-by-step Explanation
------------------------

1. **Setting Composite Log-Probabilities**:
- To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability distributions, such as SAC
or A2C), you must call `set_composite_lp_aggregate(False).set()`. Failing to do so will result in errors during
execution of your script.
- From torchrl and tensordict v0.9 onward, this will be the default behavior. Until then, omitting the call leaves
`CompositeDistribution` aggregating the log-probs, which may lead to incorrect log-probabilities.
- Note that `set_composite_lp_aggregate(False).set()` causes the sample log-probabilities to be named
`<action_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head policies,
for instance, the log-probability will be named `"action_log_prob"`.
Previously, log-prob keys defaulted to `sample_log_prob`.
2. **Action Grouping**:
- Actions can be grouped or not; PPO doesn't require them to be grouped.
- If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
log-probability, e.g., `agent0`, `agent0_log_prob`, etc.

... [...]
... action: TensorDict(
... fields={
... agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
... agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
... agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
... agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
... agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
... agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
... batch_size=torch.Size([4]),
... device=None,
... is_shared=False),

- If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.

... [...]
... agent0: TensorDict(
... fields={
... action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
... action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
... batch_size=torch.Size([4]),
... device=None,
... is_shared=False),
... agent1: TensorDict(
... fields={
... action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
... action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
... batch_size=torch.Size([4]),
... device=None,
... is_shared=False),
... agent2: TensorDict(
... fields={
... action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
... action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
... batch_size=torch.Size([4]),
... device=None,
... is_shared=False),

3. **PPO Loss Calculation**:
- Under the hood, `ClipPPOLoss` clips each head's importance weight individually (not the aggregate) and multiplies it by the advantage.

The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses.

"""

import functools

import torch
from tensordict import TensorDict
from tensordict.nn import (
CompositeDistribution,
InteractionType,
ProbabilisticTensorDictModule as Prob,
ProbabilisticTensorDictSequential as ProbSeq,
set_composite_lp_aggregate,
TensorDictModule as Mod,
TensorDictSequential as Seq,
WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

set_composite_lp_aggregate(False).set()

GROUPED_ACTIONS = False

make_params = Mod(
lambda: (
torch.ones(4),
torch.ones(4),
torch.ones(4, 2),
torch.ones(4, 2),
torch.ones(4, 10) / 10,
torch.zeros(4, 10),
torch.ones(4, 10),
),
in_keys=[],
out_keys=[
("params", "gamma", "concentration"),
("params", "gamma", "rate"),
("params", "Kumaraswamy", "concentration0"),
("params", "Kumaraswamy", "concentration1"),
("params", "mixture", "logits"),
("params", "mixture", "loc"),
("params", "mixture", "scale"),
],
)


def mixture_constructor(logits, loc, scale):
return d.MixtureSameFamily(
d.Categorical(logits=logits), d.Normal(loc=loc, scale=scale)
)


if GROUPED_ACTIONS:
name_map = {
"gamma": ("action", "agent0"),
"Kumaraswamy": ("action", "agent1"),
"mixture": ("action", "agent2"),
}
else:
name_map = {
"gamma": ("agent0", "action"),
"Kumaraswamy": ("agent1", "action"),
"mixture": ("agent2", "action"),
}

dist_constructor = functools.partial(
CompositeDistribution,
distribution_map={
"gamma": d.Gamma,
"Kumaraswamy": d.Kumaraswamy,
"mixture": mixture_constructor,
},
name_map=name_map,
)


policy = ProbSeq(
make_params,
Prob(
in_keys=["params"],
out_keys=list(name_map.values()),
distribution_class=dist_constructor,
return_log_prob=True,
default_interaction_type=InteractionType.RANDOM,
),
)

td = policy(TensorDict(batch_size=[4]))
print("Result of policy call", td)

dist = policy.get_dist(td)
log_prob = dist.log_prob(td)
print("Composite log-prob", log_prob)

# Build a dummy value operator
value_operator = Seq(
Wrap(
lambda td: td.set("state_value", torch.ones((*td.shape, 1))),
out_keys=["state_value"],
)
)

# Create fake data
data = policy(TensorDict(batch_size=[4]))
data.set(
"next",
TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

# Instantiate the loss - test the 3 different PPO losses
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
# PPO sets the keys automatically by looking at the policy
ppo = loss_cls(policy, value_operator)
print("tensor keys", ppo.tensor_keys)

# Get the loss values
loss_vals = ppo(data)
print("Loss result:", loss_cls, loss_vals)
333 changes: 274 additions & 59 deletions test/test_cost.py

Large diffs are not rendered by default.

7 changes: 4 additions & 3 deletions torchrl/modules/tensordict_module/probabilistic.py
@@ -2,6 +2,7 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

import warnings
from typing import Dict, List, Optional, Type, Union
@@ -86,7 +87,7 @@ class SafeProbabilisticModule(ProbabilisticTensorDictModule):
distribution sample will be written in the tensordict with the key
`'sample_log_prob'`. Default is ``False``.
log_prob_key (NestedKey, optional): key where to write the log_prob if return_log_prob = True.
Defaults to `'sample_log_prob'`.
Defaults to `"action_log_prob"`.
cache_dist (bool, optional): EXPERIMENTAL: if ``True``, the parameters of the
distribution (i.e. the output of the module) will be written to the
tensordict along with the sample. Those parameters can be used to re-compute
@@ -108,7 +109,7 @@ def __init__(
distribution_class: Type = Delta,
distribution_kwargs: Optional[dict] = None,
return_log_prob: bool = False,
log_prob_key: Optional[NestedKey] = "sample_log_prob",
log_prob_key: NestedKey | None = None,
cache_dist: bool = False,
n_empirical_estimate: int = 1000,
):
@@ -140,7 +141,7 @@ def __init__(
elif spec is None:
spec = Composite()
spec_keys = set(unravel_key_list(list(spec.keys(True, True))))
out_keys = set(unravel_key_list(self.out_keys))
out_keys = set(unravel_key_list(self._out_keys))
if spec_keys != out_keys:
# then assume that all the non indicated specs are None
for key in out_keys:
57 changes: 41 additions & 16 deletions torchrl/objectives/a2c.py
@@ -16,7 +16,14 @@
TensorDictBase,
TensorDictParams,
)
from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.nn import (
composite_lp_aggregate,
CompositeDistribution,
dispatch,
ProbabilisticTensorDictSequential,
set_composite_lp_aggregate,
TensorDictModule,
)
from tensordict.utils import NestedKey
from torch import distributions as d

@@ -240,10 +247,17 @@ class _AcceptedKeys:
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"
sample_log_prob: NestedKey = "sample_log_prob"
sample_log_prob: NestedKey | None = None

def __post_init__(self):
if self.sample_log_prob is None:
if composite_lp_aggregate(nowarn=True):
self.sample_log_prob = "sample_log_prob"
else:
self.sample_log_prob = "action_log_prob"

default_keys = _AcceptedKeys
tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_value_estimator: ValueEstimators = ValueEstimators.GAE

actor_network: TensorDictModule
@@ -353,6 +367,13 @@ def __init__(
else:
self.clip_value = None

log_prob_keys = self.actor_network.log_prob_keys
action_keys = self.actor_network.dist_sample_keys
if len(log_prob_keys) > 1:
self.set_keys(sample_log_prob=log_prob_keys, action=action_keys)
else:
self.set_keys(sample_log_prob=log_prob_keys[0], action=action_keys[0])

@property
def functional(self):
return self._functional
@@ -401,43 +422,47 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
def reset(self) -> None:
pass

@set_composite_lp_aggregate(False)
def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
if HAS_ENTROPY.get(type(dist), False):
entropy = dist.entropy()
else:
x = dist.rsample((self.samples_mc_entropy,))
log_prob = dist.log_prob(x)
if is_tensor_collection(log_prob):
log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
log_prob = sum(log_prob.sum(dim="feature").values(True, True))
entropy = -log_prob.mean(0)
return entropy.unsqueeze(-1)

@set_composite_lp_aggregate(False)
def _log_probs(
self, tensordict: TensorDictBase
) -> Tuple[torch.Tensor, d.Distribution]:
# current log_prob of actions
action = tensordict.get(self.tensor_keys.action)
tensordict_clone = tensordict.select(
*self.actor_network.in_keys, strict=False
).clone()
).copy()
with self.actor_network_params.to_module(
self.actor_network
) if self.functional else contextlib.nullcontext():
dist = self.actor_network.get_dist(tensordict_clone)
if isinstance(dist, CompositeDistribution):
action_keys = self.tensor_keys.action
action = tensordict.select(
*((action_keys,) if isinstance(action_keys, NestedKey) else action_keys)
)
else:
action = tensordict.get(self.tensor_keys.action)

if action.requires_grad:
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} requires grad."
)
if isinstance(action, torch.Tensor):
log_prob = dist.log_prob(action)
else:
maybe_log_prob = dist.log_prob(tensordict)
if not isinstance(maybe_log_prob, torch.Tensor):
# In some cases (Composite distribution with aggregate_probabilities toggled off) the returned type may not
# be a tensor
log_prob = maybe_log_prob.get(self.tensor_keys.sample_log_prob)
else:
log_prob = maybe_log_prob
log_prob = dist.log_prob(action)
if not isinstance(action, torch.Tensor):
log_prob = sum(
dist.log_prob(tensordict).sum(dim="feature").values(True, True)
)
log_prob = log_prob.unsqueeze(-1)
return log_prob, dist
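
The reduction used in `_log_probs` and `get_entropy_bonus` above can be illustrated in isolation. This is a hedged sketch, separate from the diff, with made-up agent keys; it assumes the per-head log-probs come back as a nested TensorDict when aggregation is disabled.

# Illustrative only: reduce a TensorDict of per-head log-probs to one log-prob per batch element.
import torch
from tensordict import TensorDict

per_head_lp = TensorDict(
    {
        "agent0": {"action_log_prob": torch.randn(4)},     # no feature dims
        "agent1": {"action_log_prob": torch.randn(4, 2)},  # one feature dim
    },
    batch_size=[4],
)
# `.sum(dim="feature")` sums each leaf over its non-batch dims;
# `.values(True, True)` iterates the nested leaves, which the builtin `sum` then adds together.
log_prob = sum(per_head_lp.sum(dim="feature").values(True, True))
print(log_prob.shape)  # torch.Size([4])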

2 changes: 1 addition & 1 deletion torchrl/objectives/common.py
@@ -247,7 +247,7 @@ def set_keys(self, **kwargs) -> None:
if value is not None:
setattr(self.tensor_keys, key, value)
else:
setattr(self.tensor_keys, key, self.default_keys.key)
setattr(self.tensor_keys, key, self.default_keys().key)

try:
self._forward_value_estimator_keys(**kwargs)
4 changes: 2 additions & 2 deletions torchrl/objectives/cql.py
@@ -261,7 +261,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0

actor_network: TensorDictModule
@@ -1026,7 +1026,7 @@ class _AcceptedKeys:
pred_val: NestedKey = "pred_val"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_qvalue",
2 changes: 1 addition & 1 deletion torchrl/objectives/crossq.py
@@ -243,7 +243,7 @@ class _AcceptedKeys:
log_prob: NestedKey = "_log_prob"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0

actor_network: ProbabilisticActor
2 changes: 1 addition & 1 deletion torchrl/objectives/ddpg.py
@@ -174,7 +174,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator: ValueEstimators = ValueEstimators.TD0
out_keys = [
"loss_actor",
4 changes: 2 additions & 2 deletions torchrl/objectives/decision_transformer.py
@@ -71,7 +71,7 @@ class _AcceptedKeys:
action_pred: NestedKey = "action"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys

actor_network: TensorDictModule
actor_network_params: TensorDictParams
@@ -282,7 +282,7 @@ class _AcceptedKeys:
action_pred: NestedKey = "action"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys

actor_network: TensorDictModule
actor_network_params: TensorDictParams
34 changes: 23 additions & 11 deletions torchrl/objectives/deprecated.py
@@ -13,7 +13,7 @@
import torch

from tensordict import TensorDict, TensorDictBase, TensorDictParams
from tensordict.nn import dispatch, TensorDictModule
from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torch import Tensor

@@ -121,14 +121,21 @@ class _AcceptedKeys:
action: NestedKey = "action"
state_action_value: NestedKey = "state_action_value"
value: NestedKey = "state_value"
log_prob: NestedKey = "_log_prob"
log_prob: NestedKey | None = None
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"

def __post_init__(self):
if self.log_prob is None:
if composite_lp_aggregate(nowarn=True):
self.log_prob = "sample_log_prob"
else:
self.log_prob = "action_log_prob"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
delay_actor: bool = False
default_value_estimator = ValueEstimators.TD0

@@ -359,12 +366,14 @@ def _actor_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, Tensor]:
tensordict_clone.select(*self.qvalue_network.in_keys, strict=False),
self._cached_detach_qvalue_network_params,
)
state_action_value = tensordict_expand.get("state_action_value").squeeze(-1)
state_action_value = tensordict_expand.get(
self.tensor_keys.state_action_value
).squeeze(-1)
loss_actor = -(
state_action_value
- self.alpha * tensordict_clone.get("sample_log_prob").squeeze(-1)
- self.alpha * tensordict_clone.get(self.tensor_keys.log_prob).squeeze(-1)
)
return loss_actor, tensordict_clone.get("sample_log_prob")
return loss_actor, tensordict_clone.get(self.tensor_keys.log_prob)

def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
tensordict_save = tensordict
@@ -389,30 +398,33 @@ def _qvalue_loss(self, tensordict: TensorDictBase) -> Tensor:
ExplorationType.RANDOM
), self.target_actor_network_params.to_module(self.actor_network):
self.actor_network(next_td)
sample_log_prob = next_td.get("sample_log_prob")
sample_log_prob = next_td.get(self.tensor_keys.log_prob)
# get q-values
next_td = self._vmap_qvalue_networkN0(
next_td,
selected_q_params,
)
state_action_value = next_td.get("state_action_value")
state_action_value = next_td.get(self.tensor_keys.state_action_value)
if (
state_action_value.shape[-len(sample_log_prob.shape) :]
!= sample_log_prob.shape
):
sample_log_prob = sample_log_prob.unsqueeze(-1)
next_state_value = (
next_td.get("state_action_value") - self.alpha * sample_log_prob
next_td.get(self.tensor_keys.state_action_value)
- self.alpha * sample_log_prob
)
next_state_value = next_state_value.min(0)[0]

tensordict.set(("next", "state_value"), next_state_value)
tensordict.set(("next", self.tensor_keys.value), next_state_value)
target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1)
tensordict_expand = self._vmap_qvalue_networkN0(
tensordict.select(*self.qvalue_network.in_keys, strict=False),
self.qvalue_network_params,
)
pred_val = tensordict_expand.get("state_action_value").squeeze(-1)
pred_val = tensordict_expand.get(self.tensor_keys.state_action_value).squeeze(
-1
)
td_error = abs(pred_val - target_value)
loss_qval = distance_loss(
pred_val,
4 changes: 2 additions & 2 deletions torchrl/objectives/dqn.py
@@ -165,7 +165,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = ["loss"]

@@ -437,7 +437,7 @@ class _AcceptedKeys:
steps_to_next_obs: NestedKey = "steps_to_next_obs"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0

value_network: TensorDictModule
6 changes: 3 additions & 3 deletions torchrl/objectives/dreamer.py
@@ -90,7 +90,7 @@ class _AcceptedKeys:
reco_pixels: NestedKey = "reco_pixels"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys

decoder: TensorDictModule
reward_model: TensorDictModule
@@ -244,7 +244,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TDLambda

value_model: TensorDictModule
@@ -402,7 +402,7 @@ class _AcceptedKeys:
value: NestedKey = "state_value"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys

value_model: TensorDictModule

2 changes: 1 addition & 1 deletion torchrl/objectives/gail.py
@@ -60,7 +60,7 @@ class _AcceptedKeys:
discriminator_pred: NestedKey = "d_logits"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys

discriminator_network: TensorDictModule
discriminator_network_params: TensorDictParams
4 changes: 2 additions & 2 deletions torchrl/objectives/iql.py
@@ -234,7 +234,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
@@ -711,7 +711,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
2 changes: 1 addition & 1 deletion torchrl/objectives/multiagent/qmixer.py
@@ -180,7 +180,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = ["loss"]

214 changes: 133 additions & 81 deletions torchrl/objectives/ppo.py

Large diffs are not rendered by default.

17 changes: 13 additions & 4 deletions torchrl/objectives/redq.py
@@ -12,7 +12,7 @@
import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams

from tensordict.nn import dispatch, TensorDictModule
from tensordict.nn import composite_lp_aggregate, dispatch, TensorDictModule
from tensordict.utils import NestedKey
from torch import Tensor

@@ -207,7 +207,9 @@ class _AcceptedKeys:
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
action (NestedKey): The input tensordict key where the action is expected. Defaults to ``"action"``.
sample_log_prob (NestedKey): The input tensordict key where the
sample log probability is expected. Defaults to ``"sample_log_prob"``.
sample log probability is expected.
Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
`"action_log_prob"` otherwise.
priority (NestedKey): The input tensordict key where the target
priority is written to. Defaults to ``"td_error"``.
state_action_value (NestedKey): The input tensordict key where the
@@ -224,15 +226,22 @@ class _AcceptedKeys:

action: NestedKey = "action"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
sample_log_prob: NestedKey | None = None
priority: NestedKey = "td_error"
state_action_value: NestedKey = "state_action_value"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"

def __post_init__(self):
if self.sample_log_prob is None:
if composite_lp_aggregate(nowarn=True):
self.sample_log_prob = "sample_log_prob"
else:
self.sample_log_prob = "action_log_prob"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
delay_actor: bool = False
default_value_estimator = ValueEstimators.TD0
out_keys = [
21 changes: 17 additions & 4 deletions torchrl/objectives/reinforce.py
@@ -11,7 +11,12 @@
import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams

from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule
from tensordict.nn import (
composite_lp_aggregate,
dispatch,
ProbabilisticTensorDictSequential,
TensorDictModule,
)
from tensordict.utils import NestedKey
from torchrl.objectives.common import LossModule

@@ -189,7 +194,8 @@ class _AcceptedKeys:
value (NestedKey): The input tensordict key where the state value is expected.
Will be used for the underlying value estimator. Defaults to ``"state_value"``.
sample_log_prob (NestedKey): The input tensordict key where the sample log probability is expected.
Defaults to ``"sample_log_prob"``.
Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
`"action_log_prob"` otherwise.
action (NestedKey): The input tensordict key where the action is expected.
Defaults to ``"action"``.
reward (NestedKey): The input tensordict key where the reward is expected.
@@ -205,14 +211,21 @@ class _AcceptedKeys:
advantage: NestedKey = "advantage"
value_target: NestedKey = "value_target"
value: NestedKey = "state_value"
sample_log_prob: NestedKey = "sample_log_prob"
sample_log_prob: NestedKey | None = None
action: NestedKey = "action"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"

def __post_init__(self):
if self.sample_log_prob is None:
if composite_lp_aggregate(nowarn=True):
self.sample_log_prob = "sample_log_prob"
else:
self.sample_log_prob = "action_log_prob"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.GAE
out_keys = ["loss_actor", "loss_value"]

47 changes: 32 additions & 15 deletions torchrl/objectives/sac.py
@@ -15,7 +15,13 @@
import torch
from tensordict import TensorDict, TensorDictBase, TensorDictParams

from tensordict.nn import dispatch, TensorDictModule
from tensordict.nn import (
composite_lp_aggregate,
CompositeDistribution,
dispatch,
set_composite_lp_aggregate,
TensorDictModule,
)
from tensordict.utils import expand_right, NestedKey
from torch import Tensor
from torchrl.data.tensor_specs import Composite, TensorSpec
@@ -46,17 +52,13 @@ def new_func(self, *args, **kwargs):
return new_func


def compute_log_prob(action_dist, action_or_tensordict, tensor_key):
def compute_log_prob(action_dist, action_or_tensordict, tensor_key) -> torch.Tensor:
"""Compute the log probability of an action given a distribution."""
if isinstance(action_or_tensordict, torch.Tensor):
log_p = action_dist.log_prob(action_or_tensordict)
else:
maybe_log_prob = action_dist.log_prob(action_or_tensordict)
if not isinstance(maybe_log_prob, torch.Tensor):
log_p = maybe_log_prob.get(tensor_key)
else:
log_p = maybe_log_prob
return log_p
lp = action_dist.log_prob(action_or_tensordict)
if isinstance(action_dist, CompositeDistribution):
with set_composite_lp_aggregate(False):
return sum(lp.sum(dim="feature").values(True, True))
return lp


class SACLoss(LossModule):
@@ -268,7 +270,8 @@ class _AcceptedKeys:
state_action_value (NestedKey): The input tensordict key where the
state action value is expected. Defaults to ``"state_action_value"``.
log_prob (NestedKey): The input tensordict key where the log probability is expected.
Defaults to ``"sample_log_prob"``.
Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
`"action_log_prob"` otherwise.
priority (NestedKey): The input tensordict key where the target priority is written to.
Defaults to ``"td_error"``.
reward (NestedKey): The input tensordict key where the reward is expected.
@@ -284,14 +287,21 @@ class _AcceptedKeys:
action: NestedKey = "action"
value: NestedKey = "state_value"
state_action_value: NestedKey = "state_action_value"
log_prob: NestedKey = "sample_log_prob"
log_prob: NestedKey | None = None
priority: NestedKey = "td_error"
reward: NestedKey = "reward"
done: NestedKey = "done"
terminated: NestedKey = "terminated"

def __post_init__(self):
if self.log_prob is None:
if composite_lp_aggregate(nowarn=True):
self.log_prob = "sample_log_prob"
else:
self.log_prob = "action_log_prob"

default_keys = _AcceptedKeys
tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_value_estimator = ValueEstimators.TD0

actor_network: TensorDictModule
@@ -426,6 +436,13 @@ def __init__(
self.reduction = reduction
self.skip_done_states = skip_done_states

log_prob_keys = getattr(self.actor_network, "log_prob_keys", [])
action_keys = getattr(self.actor_network, "dist_sample_keys", [])
if len(log_prob_keys) > 1:
self.set_keys(log_prob=log_prob_keys, action=action_keys)
else:
self.set_keys(log_prob=log_prob_keys[0], action=action_keys[0])

def _make_vmap(self):
self._vmap_qnetworkN0 = _vmap_func(
self.qvalue_network, (None, 0), randomness=self.vmap_randomness
@@ -1031,7 +1048,7 @@ class _AcceptedKeys:
log_prob: NestedKey = "log_prob"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
delay_actor: bool = False
out_keys = [
2 changes: 1 addition & 1 deletion torchrl/objectives/td3.py
@@ -205,7 +205,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
2 changes: 1 addition & 1 deletion torchrl/objectives/td3_bc.py
@@ -218,7 +218,7 @@ class _AcceptedKeys:
terminated: NestedKey = "terminated"

tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
default_keys = _AcceptedKeys
default_value_estimator = ValueEstimators.TD0
out_keys = [
"loss_actor",
27 changes: 25 additions & 2 deletions torchrl/objectives/utils.py
@@ -8,10 +8,10 @@
import re
import warnings
from enum import Enum
from typing import Iterable, Optional, Union
from typing import Iterable, List, Optional, Union

import torch
from tensordict import TensorDict, TensorDictBase
from tensordict import NestedKey, TensorDict, TensorDictBase, unravel_key
from tensordict.nn import TensorDictModule
from torch import nn, Tensor
from torch.nn import functional as F
@@ -620,3 +620,26 @@ def group_optimizers(*optimizers: torch.optim.Optimizer) -> torch.optim.Optimize
def _sum_td_features(data: TensorDictBase) -> torch.Tensor:
# Sum all features and return a tensor
return data.sum(dim="feature", reduce=True)


def _maybe_get_or_select(td, key_or_keys):
if isinstance(key_or_keys, (str, tuple)):
return td.get(key_or_keys)
return td.select(*key_or_keys)


def _maybe_add_or_extend_key(
tensor_keys: List[NestedKey],
key_or_list_of_keys: NestedKey | List[NestedKey],
prefix: NestedKey = None,
):
if prefix is not None:
if isinstance(key_or_list_of_keys, NestedKey):
tensor_keys.append(unravel_key((prefix, key_or_list_of_keys)))
else:
tensor_keys.extend([unravel_key((prefix, k)) for k in key_or_list_of_keys])
return
if isinstance(key_or_list_of_keys, NestedKey):
tensor_keys.append(key_or_list_of_keys)
else:
tensor_keys.extend(key_or_list_of_keys)
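
A brief, hypothetical usage of the two helpers above (the agent keys and shapes are made up), assuming they are imported from `torchrl.objectives.utils` where this PR defines them:

# Illustrative only: _maybe_get_or_select returns a tensor for a single key and a
# sub-TensorDict for a list of keys; _maybe_add_or_extend_key accumulates (prefixed) keys.
import torch
from tensordict import TensorDict
from torchrl.objectives.utils import _maybe_add_or_extend_key, _maybe_get_or_select

td = TensorDict(
    {"agent0": {"action": torch.zeros(4, 2)}, "agent1": {"action": torch.zeros(4, 3)}},
    batch_size=[4],
)
print(_maybe_get_or_select(td, ("agent0", "action")).shape)                    # torch.Size([4, 2])
print(_maybe_get_or_select(td, [("agent0", "action"), ("agent1", "action")]))  # sub-TensorDict

keys = []
_maybe_add_or_extend_key(keys, [("agent0", "action"), ("agent1", "action")], prefix="next")
print(keys)  # [('next', 'agent0', 'action'), ('next', 'agent1', 'action')]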
67 changes: 40 additions & 27 deletions torchrl/objectives/value/advantages.py
@@ -13,23 +13,29 @@
from typing import Callable, List, Union

import torch
from tensordict import TensorDictBase
from tensordict import is_tensor_collection, TensorDictBase
from tensordict.nn import (
CompositeDistribution,
composite_lp_aggregate,
dispatch,
ProbabilisticTensorDictModule,
set_composite_lp_aggregate,
set_skip_existing,
TensorDictModule,
TensorDictModuleBase,
)
from tensordict.nn.probabilistic import interaction_type
from tensordict.utils import NestedKey
from tensordict.utils import NestedKey, unravel_key
from torch import Tensor

from torchrl._utils import RL_WARNINGS
from torchrl.envs.utils import step_mdp

from torchrl.objectives.utils import _vmap_func, hold_out_net, RANDOM_MODULE_LIST
from torchrl.objectives.utils import (
_maybe_get_or_select,
_vmap_func,
hold_out_net,
RANDOM_MODULE_LIST,
)
from torchrl.objectives.value.functional import (
generalized_advantage_estimate,
td0_return_estimate,
@@ -83,16 +89,9 @@ def _call_actor_net(
log_prob_key: NestedKey,
):
dist = actor_net.get_dist(data.select(*actor_net.in_keys, strict=False))
if isinstance(dist, CompositeDistribution):
kwargs = {
"aggregate_probabilities": True,
"inplace": False,
"include_sum": False,
}
else:
kwargs = {}
s = actor_net._dist_sample(dist, interaction_type=interaction_type())
return dist.log_prob(s, **kwargs)
with set_composite_lp_aggregate(True):
return dist.log_prob(s)


class ValueEstimatorBase(TensorDictModuleBase):
@@ -131,7 +130,9 @@ class _AcceptedKeys:
that indicates the number of steps to the next observation.
Defaults to ``"steps_to_next_obs"``.
sample_log_prob (NestedKey): The key in the input tensordict that
indicates the log probability of the sampled action. Defaults to ``"sample_log_prob"``.
indicates the log probability of the sampled action.
Defaults to ``"sample_log_prob"`` when :func:`~tensordict.nn.composite_lp_aggregate` returns `True`,
`"action_log_prob"` otherwise.
"""

advantage: NestedKey = "advantage"
@@ -141,10 +142,17 @@ class _AcceptedKeys:
done: NestedKey = "done"
terminated: NestedKey = "terminated"
steps_to_next_obs: NestedKey = "steps_to_next_obs"
sample_log_prob: NestedKey = "sample_log_prob"
sample_log_prob: NestedKey | None = None

def __post_init__(self):
if self.sample_log_prob is None:
if composite_lp_aggregate(nowarn=True):
self.sample_log_prob = "sample_log_prob"
else:
self.sample_log_prob = "action_log_prob"

default_keys = _AcceptedKeys
tensor_keys: _AcceptedKeys
default_keys = _AcceptedKeys()
value_network: Union[TensorDictModule, Callable]
_vmap_randomness = None

@@ -294,13 +302,18 @@ def out_keys(self):

def set_keys(self, **kwargs) -> None:
"""Set tensordict key names."""
for key, value in kwargs.items():
if not isinstance(value, (str, tuple)):
for key, value in list(kwargs.items()):
if isinstance(value, list):
value = [unravel_key(k) for k in value]
elif not isinstance(value, (str, tuple)):
if value is None:
raise ValueError("tensordict keys cannot be None")
raise ValueError(
f"key name must be of type NestedKey (Union[str, Tuple[str]]) but got {type(value)}"
)
if value is None:
raise ValueError("tensordict keys cannot be None")
else:
value = unravel_key(value)

if key not in self._AcceptedKeys.__dict__:
raise KeyError(
f"{key} is not an accepted tensordict key for advantages"
@@ -313,8 +326,9 @@ def set_keys(self, **kwargs) -> None:
raise KeyError(
f"value key '{value}' not found in value network out_keys {self.value_network.out_keys}"
)
kwargs[key] = value
if self._tensor_keys is None:
conf = asdict(self.default_keys)
conf = asdict(self.default_keys())
conf.update(self.dep_keys)
else:
conf = asdict(self._tensor_keys)
@@ -1766,12 +1780,11 @@ def forward(
value = tensordict.get(self.tensor_keys.value)
next_value = tensordict.get(("next", self.tensor_keys.value))

# Make sure we have the log prob computed at collection time
if self.tensor_keys.sample_log_prob not in tensordict.keys():
raise ValueError(
f"Expected {self.tensor_keys.sample_log_prob} to be in tensordict"
)
log_mu = tensordict.get(self.tensor_keys.sample_log_prob).view_as(value)
lp = _maybe_get_or_select(tensordict, self.tensor_keys.sample_log_prob)
if is_tensor_collection(lp):
# Sum all values to match the batch size
lp = lp.sum(dim="feature", reduce=True)
log_mu = lp.view_as(value)

# Compute log prob with current policy
with hold_out_net(self.actor_network):