# LICENSE file in the root directory of this source tree.

"""
- Multi-head agent and PPO loss
+ Multi-head Agent and PPO Loss
=============================
-
This example demonstrates how to use TorchRL to create a multi-head agent with three separate distributions
(Gamma, Kumaraswamy, and Mixture) and train it using Proximal Policy Optimization (PPO) losses.

- The code first defines a module `make_params` that extracts the parameters of the distributions from an input tensordict.
- It then creates a `dist_constructor` function that takes these parameters as input and outputs a CompositeDistribution
- object containing the three distributions.
-
- The policy is defined as a ProbabilisticTensorDictSequential module that reads an observation, casts it to parameters,
- creates a distribution from these parameters, and samples from the distribution to output multiple actions.
-
- The example tests the policy with fake data across three different PPO losses: PPOLoss, ClipPPOLoss, and KLPENPPOLoss.
-
- Note that the `log_prob` method of the CompositeDistribution object can return either an aggregated tensor or a
- fine-grained tensordict with individual log-probabilities, depending on the value of the `aggregate_probabilities`
- argument. The PPO loss modules are designed to handle both cases, and will default to `aggregate_probabilities=False`
- if not specified.
-
- In particular, if `aggregate_probabilities=False` and `include_sum=True`, the summed log-probs will also be included in
- the output tensordict. However, since we have access to the individual log-probs, this feature is not typically used.
+ Step-by-step Explanation
+ ------------------------
+
+ 1. **Setting Composite Log-Probabilities**:
+    - To use composite (i.e., multi-head) distributions with PPO (or any other algorithm that relies on probability
+      distributions, such as SAC or A2C), you must call `set_composite_lp_aggregate(False).set()`. Failing to do so
+      will result in errors during execution of your script.
+    - From torchrl and tensordict v0.9 onward, this will be the default behavior. Until then, omitting this call
+      results in `CompositeDistribution` aggregating the log-probs, which may lead to incorrect log-probabilities.
+    - Note that `set_composite_lp_aggregate(False).set()` will cause the sample log-probabilities to be named
+      `<action_key>_log_prob` for any probability distribution, not just composite ones. For regular, single-head
+      policies, for instance, the log-probability will be named `"action_log_prob"`. Previously, log-prob keys
+      defaulted to `sample_log_prob`.
+ 2. **Action Grouping**:
+    - Actions can be grouped or not; PPO does not require them to be grouped.
+    - If actions are grouped, calling the policy will result in a `TensorDict` with fields for each agent's action and
+      log-probability, e.g., `agent0`, `agent0_log_prob`, etc.
+
+      ... [...]
+      ...     action: TensorDict(
+      ...         fields={
+      ...             agent0: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             agent0_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             agent1: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             agent1_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             agent2: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             agent2_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+      ...         batch_size=torch.Size([4]),
+      ...         device=None,
+      ...         is_shared=False),
+
+    - If actions are not grouped, each agent will have its own `TensorDict` with `action` and `action_log_prob` fields.
+
+      ... [...]
+      ...     agent0: TensorDict(
+      ...         fields={
+      ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+      ...         batch_size=torch.Size([4]),
+      ...         device=None,
+      ...         is_shared=False),
+      ...     agent1: TensorDict(
+      ...         fields={
+      ...             action: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             action_log_prob: Tensor(shape=torch.Size([4, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
+      ...         batch_size=torch.Size([4]),
+      ...         device=None,
+      ...         is_shared=False),
+      ...     agent2: TensorDict(
+      ...         fields={
+      ...             action: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False),
+      ...             action_log_prob: Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, is_shared=False)},
+      ...         batch_size=torch.Size([4]),
+      ...         device=None,
+      ...         is_shared=False),
+
+ 3. **PPO Loss Calculation**:
+    - Under the hood, `ClipPPOLoss` will clip the individual importance weights (not their aggregate) and multiply
+      them by the advantage.
+
+ The code below sets up a multi-head agent with three distributions and demonstrates how to train it using PPO losses.

"""

@@ ... @@
    InteractionType,
    ProbabilisticTensorDictModule as Prob,
    ProbabilisticTensorDictSequential as ProbSeq,
+     set_composite_lp_aggregate,
    TensorDictModule as Mod,
    TensorDictSequential as Seq,
    WrapModule as Wrap,
)
from torch import distributions as d
from torchrl.objectives import ClipPPOLoss, KLPENPPOLoss, PPOLoss

+ set_composite_lp_aggregate(False).set()
+
+ GROUPED_ACTIONS = False
+
make_params = Mod(
    lambda: (
        torch.ones(4),
@@ -74,8 +122,18 @@ def mixture_constructor(logits, loc, scale):
)


- # =============================================================================
- # Example 0: aggregate_probabilities=None (default) ===========================
+ if GROUPED_ACTIONS:
+     name_map = {
+         "gamma": ("action", "agent0"),
+         "Kumaraswamy": ("action", "agent1"),
+         "mixture": ("action", "agent2"),
+     }
+ else:
+     name_map = {
+         "gamma": ("agent0", "action"),
+         "Kumaraswamy": ("agent1", "action"),
+         "mixture": ("agent2", "action"),
+     }

dist_constructor = functools.partial(
    CompositeDistribution,
@@ -84,40 +142,27 @@ def mixture_constructor(logits, loc, scale):
        "Kumaraswamy": d.Kumaraswamy,
        "mixture": mixture_constructor,
    },
-     name_map={
-         "gamma": ("agent0", "action"),
-         "Kumaraswamy": ("agent1", "action"),
-         "mixture": ("agent2", "action"),
-     },
-     aggregate_probabilities=None,
+     name_map=name_map,
)


policy = ProbSeq(
    make_params,
    Prob(
        in_keys=["params"],
-         out_keys=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
+         out_keys=list(name_map.values()),
        distribution_class=dist_constructor,
        return_log_prob=True,
        default_interaction_type=InteractionType.RANDOM,
    ),
)

td = policy(TensorDict(batch_size=[4]))
- print("0. result of policy call", td)
+ print("Result of policy call", td)

dist = policy.get_dist(td)
- log_prob = dist.log_prob(
-     td, aggregate_probabilities=False, inplace=False, include_sum=False
- )
- print("0. non-aggregated log-prob")
-
- # We can also get the log-prob from the policy directly
- log_prob = policy.log_prob(
-     td, aggregate_probabilities=False, inplace=False, include_sum=False
- )
- print("0. non-aggregated log-prob (from policy)")
+ log_prob = dist.log_prob(td)
+ print("Composite log-prob", log_prob)

# Build a dummy value operator
value_operator = Seq(
@@ -134,70 +179,12 @@ def mixture_constructor(logits, loc, scale):
    TensorDict(reward=torch.randn(4, 1), done=torch.zeros(4, 1, dtype=torch.bool)),
)

- # Instantiate the loss
+ # Instantiate the loss - test the 3 different PPO losses
for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
+     # PPO sets the keys automatically by looking at the policy
    ppo = loss_cls(policy, value_operator)
-
-     # Keys are not the default ones - there is more than one action
-     ppo.set_keys(
-         action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
-         sample_log_prob=[
-             ("agent0", "action_log_prob"),
-             ("agent1", "action_log_prob"),
-             ("agent2", "action_log_prob"),
-         ],
-     )
-
-     # Get the loss values
-     loss_vals = ppo(data)
-     print("0. ", loss_cls, loss_vals)
-
-
- # ===================================================================
- # Example 1: aggregate_probabilities=True ===========================
-
- dist_constructor.keywords["aggregate_probabilities"] = True
-
- td = policy(TensorDict(batch_size=[4]))
- print("1. result of policy call", td)
-
- # Instantiate the loss
- for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
-     ppo = loss_cls(policy, value_operator)
-
-     # Keys are not the default ones - there is more than one action. No need to indicate the sample-log-prob key, since
-     # there is only one.
-     ppo.set_keys(
-         action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")]
-     )
-
-     # Get the loss values
-     loss_vals = ppo(data)
-     print("1. ", loss_cls, loss_vals)
-
-
- # ===================================================================
- # Example 2: aggregate_probabilities=False ===========================
-
- dist_constructor.keywords["aggregate_probabilities"] = False
-
- td = policy(TensorDict(batch_size=[4]))
- print("2. result of policy call", td)
-
- # Instantiate the loss
- for loss_cls in (PPOLoss, ClipPPOLoss, KLPENPPOLoss):
-     ppo = loss_cls(policy, value_operator)
-
-     # Keys are not the default ones - there is more than one action
-     ppo.set_keys(
-         action=[("agent0", "action"), ("agent1", "action"), ("agent2", "action")],
-         sample_log_prob=[
-             ("agent0", "action_log_prob"),
-             ("agent1", "action_log_prob"),
-             ("agent2", "action_log_prob"),
-         ],
-     )
+     print("tensor keys", ppo.tensor_keys)

    # Get the loss values
    loss_vals = ppo(data)
-     print("2. ", loss_cls, loss_vals)
+     print("Loss result: ", loss_cls, loss_vals)
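To unpack point 3 (per-head clipping), below is a rough, tensor-only illustration of the clipped surrogate computed head by head: each head's importance weight is clipped on its own and multiplied by the advantage. This is a sketch of the idea, not the internal implementation of `ClipPPOLoss`; the head names, shapes, dummy log-probs, and the final reduction are placeholders.

```python
import torch

eps = 0.2  # clip threshold (ClipPPOLoss uses an epsilon of this kind)
advantage = torch.randn(4, 1)

# One (new, old) log-prob pair per head, mirroring agent0/agent1/agent2 above
log_probs_new = {
    "agent0": torch.randn(4),
    "agent1": torch.randn(4, 2),
    "agent2": torch.randn(4),
}
log_probs_old = {k: v + 0.1 * torch.randn_like(v) for k, v in log_probs_new.items()}

head_losses = []
for key, lp_new in log_probs_new.items():
    # Importance weight for this head only -- no summing of log-probs across heads
    ratio = (lp_new - log_probs_old[key]).exp()
    adv = advantage if ratio.dim() > 1 else advantage.squeeze(-1)
    # Clip the per-head weight, then take the pessimistic (minimum) surrogate
    surrogate = torch.min(ratio * adv, ratio.clamp(1 - eps, 1 + eps) * adv)
    head_losses.append(-surrogate.mean())

# How the per-head terms are combined is up to the loss module; summing is just a placeholder
loss_objective = torch.stack(head_losses).sum()
print(loss_objective)
```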