
Commit b05d735 (parent: f2fa8e9)

Commit message: Update
[ghstack-poisoned]

File tree: 2 files changed (+89 -100 lines)


examples/agents/composite_ppo.py (+87 -100)
@@ -4,28 +4,71 @@
 # LICENSE file in the root directory of this source tree.
 
 """
-Multi-head agent and PPO loss
+Multi-head Agent and PPO Loss
 =============================
-
 This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
 (Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.
 
-The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict.
-It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution
-object containing the three distributions.
-
-The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
-creates a distribution from these parameters, and samples from the distribution to output multiple actions.
-
-The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.
-
-Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
-fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
-argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
-if not specified.
-
-In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
-the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.
+Step-by-step Explanation
+------------------------
+
+1. **Setting Composite Log-Probabilities**:
+   - To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability
+     distributions, such as SAC or A2C), you must call `set_composite_lp_aggregate(False).set()`. Not calling this
+     will result in errors during the execution of your script.
+   - From torchrl and tensordict v0.9, this will be the default behavior. Until then, omitting the call will result in
+     `CompositeDistribution` aggregating the log-probs, which may lead to incorrect log-probabilities.
+   - Note that `set_composite_lp_aggregate(False).set()` will cause the sample log-probabilities to be named
+     `<action_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head
+     policies, for instance, the log-probability will be named `"action_log_prob"`.
+     Previously, log-prob keys defaulted to `sample_log_prob`.
+2. **Action Grouping**:
+   - Actions can be grouped or not; PPO doesn't require them to be grouped.
+   - If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
+     log-probability, e.g., `agent0`, `agent0_log_prob`, etc.
+
+     ... [...]
+     ...     action: TensorDict(
+     ...         fields={
+     ...             agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+     ...         batch_size=torch.Size([4]),
+     ...         device=None,
+     ...         is_shared=False),
+
+   - If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.
+
+     ... [...]
+     ...     agent0: TensorDict(
+     ...         fields={
+     ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+     ...         batch_size=torch.Size([4]),
+     ...         device=None,
+     ...         is_shared=False),
+     ...     agent1: TensorDict(
+     ...         fields={
+     ...             action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
+     ...         batch_size=torch.Size([4]),
+     ...         device=None,
+     ...         is_shared=False),
+     ...     agent2: TensorDict(
+     ...         fields={
+     ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+     ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+     ...         batch_size=torch.Size([4]),
+     ...         device=None,
+     ...         is_shared=False),
+
+3. **PPO Loss Calculation**:
+   - Under the hood, `ClipPPOLoss` clips the individual importance weights (not the aggregate) and multiplies each by
+     the advantage.
+
+The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses.
 
 """
 
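The key-naming behavior described in step 1 of the docstring can be checked in isolation. Below is a minimal sketch, not part of the commit: the single-head Normal policy, the "observation" key, and the toy parameter mapping are illustrative assumptions, but the `set_composite_lp_aggregate` call and the resulting `"action_log_prob"` key follow the docstring above.

import torch
from tensordict import TensorDict
from tensordict.nn import (
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
    TensorDictModule as Mod,
    set_composite_lp_aggregate,
)
from torch import distributions as d

# Disable log-prob aggregation globally, as the example script does.
set_composite_lp_aggregate(False).set()

# A toy single-head policy: map an observation to Normal parameters and sample.
single_head = ProbSeq(
    Mod(
        lambda obs: (obs.mean(-1), obs.std(-1).clamp_min(1e-3)),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    Prob(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=d.Normal,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = single_head(TensorDict(observation=torch.randn(4, 3), batch_size=[4]))
# The log-prob entry follows the "<action_key>_log_prob" convention, not "sample_log_prob".
assert "action_log_prob" in td.keys()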
@@ -38,13 +81,18 @@
     InteractionType,
     ProbabilisticTensorDictModule as Prob,
     ProbabilisticTensorDictSequential as ProbSeq,
+    set_composite_lp_aggregate,
     TensorDictModule as Mod,
     TensorDictSequential as Seq,
     WrapModule as Wrap,
 )
 from torch import distributions as d
 from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss
 
+set_composite_lp_aggregate(False).set()
+
+GROUPED_ACTIONS = False
+
 make_params = Mod(
     lambda: (
         torch.ones(4),
@@ -74,8 +122,18 @@ def mixture_constructor(logits, loc, scale):
 )
 
 
-# =============================================================================
-# Example 0: aggregate_probabilities=None (default) ===========================
+if GROUPED_ACTIONS:
+    name_map = {
+        "gamma": ("action", "agent0"),
+        "Kumaraswamy": ("action", "agent1"),
+        "mixture": ("action", "agent2"),
+    }
+else:
+    name_map = {
+        "gamma": ("agent0", "action"),
+        "Kumaraswamy": ("agent1", "action"),
+        "mixture": ("agent2", "action"),
+    }
 
 dist_constructor = functools.partial(
     CompositeDistribution,
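For readers unfamiliar with nested TensorDict keys, the throwaway snippet below (an illustration, not part of the commit) shows how the two `name_map` conventions above place entries: `("agent0", "action")` nests an "action" leaf inside an "agent0" sub-tensordict, while `("action", "agent0")` groups every agent under a single "action" sub-tensordict.

import torch
from tensordict import TensorDict

td = TensorDict(batch_size=[4])
# Ungrouped convention: one sub-tensordict per agent, each holding its own "action".
td["agent0", "action"] = torch.randn(4)
# Grouped convention: a single "action" sub-tensordict holding one entry per agent.
td["action", "agent0"] = torch.randn(4)
print(td)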
@@ -84,40 +142,27 @@ def mixture_constructor(logits, loc, scale):
         "Kumaraswamy": d.Kumaraswamy,
         "mixture": mixture_constructor,
     },
-    name_map={
-        "gamma": ("agent0", "action"),
-        "Kumaraswamy": ("agent1", "action"),
-        "mixture": ("agent2", "action"),
-    },
-    aggregate_probabilities=None,
+    name_map=name_map,
 )
 
 
 policy = ProbSeq(
     make_params,
     Prob(
         in_keys=["params"],
-        out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
+        out_keys=list(name_map.values()),
         distribution_class=dist_constructor,
         return_log_prob=True,
         default_interaction_type=InteractionType.RANDOM,
     ),
 )
 
 td = policy(TensorDict(batch_size=[4]))
-print("0. result of policy call", td)
+print("Result of policy call", td)
 
 dist = policy.get_dist(td)
-log_prob = dist.log_prob(
-    td, aggregate_probabilities=False, inplace=False, include_sum=False
-)
-print("0. non-aggregated log-prob")
-
-# We can also get the log-prob from the policy directly
-log_prob = policy.log_prob(
-    td, aggregate_probabilities=False, inplace=False, include_sum=False
-)
-print("0. non-aggregated log-prob (from policy)")
+log_prob = dist.log_prob(td)
+print("Composite log-prob", log_prob)
 
 # Build a dummy value operator
 value_operator = Seq(
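With aggregation disabled, the `log_prob` printed above is a tensordict of per-head log-probabilities rather than a single tensor. The snippet below is a self-contained sketch: the tensordict is fabricated to mimic the key layout and shapes of this example, and shows one way to reduce such a structure to a single log-prob per sample if an aggregate is ever needed.

import torch
from tensordict import TensorDict

# Fabricated stand-in for the non-aggregated log-prob: one leaf per head,
# shaped like the corresponding action.
log_prob = TensorDict(
    {
        "agent0": {"action_log_prob": torch.randn(4)},
        "agent1": {"action_log_prob": torch.randn(4, 2)},
        "agent2": {"action_log_prob": torch.randn(4)},
    },
    batch_size=[4],
)

# Flatten the feature dimensions of every leaf and sum the heads to get one
# log-prob per batch element.
total = sum(lp.reshape(4, -1).sum(-1) for lp in log_prob.values(True, True))
print(total.shape)  # torch.Size([4])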
@@ -134,70 +179,12 @@ def mixture_constructor(logits, loc, scale):
     TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
 )
 
-# Instantiate the loss
+# Instantiate the loss - test the 3 different PPO losses
 for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
+    # PPO sets the keys automatically by looking at the policy
     ppo = loss_cls(policy, value_operator)
-
-    # Keys are not the default ones - there is more than one action
-    ppo.set_keys(
-        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
-        sample_log_prob=[
-            ("agent0", "action_log_prob"),
-            ("agent1", "action_log_prob"),
-            ("agent2", "action_log_prob"),
-        ],
-    )
-
-    # Get the loss values
-    loss_vals = ppo(data)
-    print("0. ", loss_cls, loss_vals)
-
-
-# ===================================================================
-# Example 1: aggregate_probabilities=True ===========================
-
-dist_constructor.keywords["aggregate_probabilities"] = True
-
-td = policy(TensorDict(batch_size=[4]))
-print("1. result of policy call", td)
-
-# Instantiate the loss
-for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
-    ppo = loss_cls(policy, value_operator)
-
-    # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since
-    # there is only one.
-    ppo.set_keys(
-        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
-    )
-
-    # Get the loss values
-    loss_vals = ppo(data)
-    print("1. ", loss_cls, loss_vals)
-
-
-# ===================================================================
-# Example 2: aggregate_probabilities=False ===========================
-
-dist_constructor.keywords["aggregate_probabilities"] = False
-
-td = policy(TensorDict(batch_size=[4]))
-print("2. result of policy call", td)
-
-# Instantiate the loss
-for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
-    ppo = loss_cls(policy, value_operator)
-
-    # Keys are not the default ones - there is more than one action
-    ppo.set_keys(
-        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
-        sample_log_prob=[
-            ("agent0", "action_log_prob"),
-            ("agent1", "action_log_prob"),
-            ("agent2", "action_log_prob"),
-        ],
-    )
+    print("tensor keys", ppo.tensor_keys)
 
     # Get the loss values
     loss_vals = ppo(data)
-    print("2. ", loss_cls, loss_vals)
+    print("Loss result:", loss_cls, loss_vals)
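Since the loss keys are now inferred from the policy, the explicit key configuration removed above is no longer necessary. If the inferred defaults ever need to be overridden, the pattern from the removed lines can still be used inside the loop body; the sketch below simply reuses that removed code and is not something this commit adds.

    # Manually point the loss at the per-agent action and log-prob entries.
    ppo.set_keys(
        action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
        sample_log_prob=[
            ("agent0", "action_log_prob"),
            ("agent1", "action_log_prob"),
            ("agent2", "action_log_prob"),
        ],
    )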

torchrl/objectives/ppo.py (+2 -0)
@@ -959,6 +959,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         # to different, unrelated trajectories, which is not standard. Still it can give an idea of the dispersion
         # of the weights.
         lw = log_weight.squeeze()
+        if not isinstance(lw, torch.Tensor):
+            lw = _sum_td_features(lw)
         ess = (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()
         batch = log_weight.shape[0]
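The two added lines guard the effective-sample-size (ESS) diagnostic: with non-aggregated composite log-probs, `log_weight` can be a TensorDict of per-head tensors rather than a single Tensor, so its leaves must be collapsed into one tensor before `logsumexp` is applied. The sketch below is a hedged illustration of that guard; `sum_td_features` and `ess_from_log_weight` are hypothetical stand-ins written for this note, not torchrl's private helpers.

import torch
from tensordict import TensorDict

def sum_td_features(td: TensorDict) -> torch.Tensor:
    # Hypothetical stand-in: reduce every leaf over its feature dimensions and add
    # the heads together, yielding one log-weight per batch element.
    batch = td.batch_size[0]
    return sum(leaf.reshape(batch, -1).sum(-1) for leaf in td.values(True, True))

def ess_from_log_weight(log_weight) -> torch.Tensor:
    lw = log_weight.squeeze()          # works for both Tensor and TensorDict
    if not isinstance(lw, torch.Tensor):
        lw = sum_td_features(lw)       # collapse the per-head log-weights first
    # Effective sample size of the importance weights, as in the patched forward above.
    return (2 * lw.logsumexp(0) - (2 * lw).logsumexp(0)).exp()

print(ess_from_log_weight(torch.randn(8)))  # plain-Tensor path, unchanged by the patch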
