@@ -8,14 +8,15 @@

 from copy import deepcopy
 from dataclasses import dataclass
-from typing import Tuple
+from typing import List, Tuple

 import torch
 from tensordict import (
     is_tensor_collection,
     TensorDict,
     TensorDictBase,
     TensorDictParams,
+    unravel_key,
 )
 from tensordict.nn import (
     CompositeDistribution,
@@ -33,6 +34,8 @@
     _cache_values,
     _clip_value_loss,
     _GAMMA_LMBDA_DEPREC_ERROR,
+    _maybe_add_or_extend_key,
+    _maybe_get_or_select,
     _reduce,
     _sum_td_features,
     default_value_kwargs,
@@ -67,7 +70,10 @@ class PPOLoss(LossModule):

    Args:
        actor_network (ProbabilisticTensorDictSequential): policy operator.
-        critic_network (ValueOperator): value operator.
+            Typically a :class:`~tensordict.nn.ProbabilisticTensorDictSequential` subclass taking observations
+            as input and outputting an action (or actions) as well as its log-probability value.
+        critic_network (ValueOperator): value operator. The critic will usually take the observations as input
+            and return a scalar value (``state_value`` by default) in the output keys.

    Keyword Args:
        entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the
@@ -267,28 +273,28 @@ class _AcceptedKeys:
                Will be used for the underlying value estimator Defaults to ``"value_target"``.
            value (NestedKey): The input tensordict key where the state value is expected.
                Will be used for the underlying value estimator. Defaults to ``"state_value"``.
-            sample_log_prob (NestedKey): The input tensordict key where the
+            sample_log_prob (NestedKey or list of nested keys): The input tensordict key where the
                sample log probability is expected. Defaults to ``"sample_log_prob"``.
-            action (NestedKey): The input tensordict key where the action is expected.
+            action (NestedKey or list of nested keys): The input tensordict key where the action is expected.
                Defaults to ``"action"``.
-            reward (NestedKey): The input tensordict key where the reward is expected.
+            reward (NestedKey or list of nested keys): The input tensordict key where the reward is expected.
                Will be used for the underlying value estimator. Defaults to ``"reward"``.
-            done (NestedKey): The key in the input TensorDict that indicates
+            done (NestedKey or list of nested keys): The key in the input TensorDict that indicates
                whether a trajectory is done. Will be used for the underlying value estimator.
                Defaults to ``"done"``.
-            terminated (NestedKey): The key in the input TensorDict that indicates
+            terminated (NestedKey or list of nested keys): The key in the input TensorDict that indicates
                whether a trajectory is terminated. Will be used for the underlying value estimator.
                Defaults to ``"terminated"``.
        """

        advantage: NestedKey = "advantage"
        value_target: NestedKey = "value_target"
        value: NestedKey = "state_value"
-        sample_log_prob: NestedKey = "sample_log_prob"
-        action: NestedKey = "action"
-        reward: NestedKey = "reward"
-        done: NestedKey = "done"
-        terminated: NestedKey = "terminated"
+        sample_log_prob: NestedKey | List[NestedKey] = "sample_log_prob"
+        action: NestedKey | List[NestedKey] = "action"
+        reward: NestedKey | List[NestedKey] = "reward"
+        done: NestedKey | List[NestedKey] = "done"
+        terminated: NestedKey | List[NestedKey] = "terminated"

    default_keys = _AcceptedKeys()
    default_value_estimator = ValueEstimators.GAE
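
With the accepted keys widened to lists of nested keys, a single loss can be pointed at several per-agent entries at once. Below is a minimal, hedged sketch of such a configuration: `actor` and `critic` are assumed to be already-built actor/critic modules, and the `("agent0", ...)` / `("agent1", ...)` grouping keys are hypothetical; only `set_keys` and the list-of-nested-keys acceptance come from this change.

    from torchrl.objectives import PPOLoss

    # Hypothetical multi-agent layout: each agent stores its own entries
    # under its own sub-tensordict.
    loss_module = PPOLoss(actor_network=actor, critic_network=critic)
    loss_module.set_keys(
        action=[("agent0", "action"), ("agent1", "action")],
        sample_log_prob=[("agent0", "sample_log_prob"), ("agent1", "sample_log_prob")],
        reward=[("agent0", "reward"), ("agent1", "reward")],
        done=[("agent0", "done"), ("agent1", "done")],
        terminated=[("agent0", "terminated"), ("agent1", "terminated")],
    )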
@@ -369,7 +375,7 @@ def __init__(

        try:
            device = next(self.parameters()).device
-        except AttributeError:
+        except (AttributeError, StopIteration):
            device = torch.device("cpu")

        self.register_buffer("entropy_coef", torch.tensor(entropy_coef, device=device))
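
The broadened `except` covers the case where the loss module holds no parameters at all: calling `next()` on an empty parameter iterator raises `StopIteration`, not `AttributeError`. A small self-contained illustration of the failure mode being guarded against:

    import torch

    module = torch.nn.Module()  # a module with no parameters
    try:
        device = next(module.parameters()).device
    except (AttributeError, StopIteration):
        # next() raised StopIteration because parameters() yielded nothing;
        # fall back to CPU, as the patched code does.
        device = torch.device("cpu")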
@@ -409,15 +415,36 @@ def functional(self):

    def _set_in_keys(self):
        keys = [
-            self.tensor_keys.action,
-            self.tensor_keys.sample_log_prob,
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
-            ("next", self.tensor_keys.terminated),
            *self.actor_network.in_keys,
            *[("next", key) for key in self.actor_network.in_keys],
            *self.critic_network.in_keys,
        ]
+
+        if isinstance(self.tensor_keys.action, NestedKey):
+            keys.append(self.tensor_keys.action)
+        else:
+            keys.extend(self.tensor_keys.action)
+
+        if isinstance(self.tensor_keys.sample_log_prob, NestedKey):
+            keys.append(self.tensor_keys.sample_log_prob)
+        else:
+            keys.extend(self.tensor_keys.sample_log_prob)
+
+        if isinstance(self.tensor_keys.reward, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.reward)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.reward])
+
+        if isinstance(self.tensor_keys.done, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.done)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.done])
+
+        if isinstance(self.tensor_keys.terminated, NestedKey):
+            keys.append(unravel_key(("next", self.tensor_keys.terminated)))
+        else:
+            keys.extend([unravel_key(("next", k)) for k in self.tensor_keys.terminated])
+
        self._in_keys = list(set(keys))

    @property
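
For reference, `unravel_key` from `tensordict` flattens a nested-key tuple into a single flat tuple, which is why it is applied when prefixing user-supplied keys with `"next"`. A short example; the `("agents", "reward")` key is made up for illustration, and the commented results are the expected outputs:

    from tensordict import unravel_key

    unravel_key(("next", "reward"))              # -> ("next", "reward")
    unravel_key(("next", ("agents", "reward")))  # -> ("next", "agents", "reward")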
@@ -472,25 +499,38 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
            if is_tensor_collection(entropy):
                entropy = _sum_td_features(entropy)
        except NotImplementedError:
-            x = dist.rsample((self.samples_mc_entropy,))
+            if getattr(dist, "has_rsample", False):
+                x = dist.rsample((self.samples_mc_entropy,))
+            else:
+                x = dist.sample((self.samples_mc_entropy,))
            log_prob = dist.log_prob(x)
-            if is_tensor_collection(log_prob):
+
+            if is_tensor_collection(log_prob) and isinstance(
+                self.tensor_keys.sample_log_prob, NestedKey
+            ):
                log_prob = log_prob.get(self.tensor_keys.sample_log_prob)
+            else:
+                log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)
+
            entropy = -log_prob.mean(0)
        return entropy.unsqueeze(-1)

    def _log_weight(
        self, tensordict: TensorDictBase
    ) -> Tuple[torch.Tensor, d.Distribution]:
+
        # current log_prob of actions
-        action = tensordict.get(self.tensor_keys.action)
+        action = _maybe_get_or_select(tensordict, self.tensor_keys.action)

        with self.actor_network_params.to_module(
            self.actor_network
        ) if self.functional else contextlib.nullcontext():
            dist = self.actor_network.get_dist(tensordict)

-        prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob)
+        prev_log_prob = _maybe_get_or_select(
+            tensordict, self.tensor_keys.sample_log_prob
+        )
+
        if prev_log_prob.requires_grad:
            raise RuntimeError(
                f"tensordict stored {self.tensor_keys.sample_log_prob} requires grad."
@@ -513,8 +553,8 @@ def _log_weight(
        else:
            is_composite = False
            kwargs = {}
-            log_prob = dist.log_prob(tensordict, **kwargs)
-            if is_composite and not isinstance(prev_log_prob, TensorDict):
+            log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
+            if is_composite and not is_tensor_collection(prev_log_prob):
                log_prob = _sum_td_features(log_prob)
            log_prob.view_as(prev_log_prob)

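`_maybe_get_or_select`, imported at the top of this file, is not defined in this diff. A plausible sketch of its behavior, inferred from the call sites above (an assumption, not the actual implementation):

    from tensordict import TensorDictBase


    def _maybe_get_or_select(td: TensorDictBase, key_or_keys):
        """Fetch one tensor for a single nested key, or a sub-tensordict for a list of keys."""
        if isinstance(key_or_keys, (str, tuple)):
            # Single NestedKey (a string or a tuple of strings): return that entry.
            return td.get(key_or_keys)
        # List of nested keys: keep only those entries, returned as a tensordict.
        return td.select(*key_or_keys)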
@@ -1088,15 +1128,16 @@ def __init__(

    def _set_in_keys(self):
        keys = [
-            self.tensor_keys.action,
-            self.tensor_keys.sample_log_prob,
-            ("next", self.tensor_keys.reward),
-            ("next", self.tensor_keys.done),
-            ("next", self.tensor_keys.terminated),
            *self.actor_network.in_keys,
            *[("next", key) for key in self.actor_network.in_keys],
            *self.critic_network.in_keys,
        ]
+        _maybe_add_or_extend_key(keys, self.tensor_keys.action)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.sample_log_prob)
+        _maybe_add_or_extend_key(keys, self.tensor_keys.reward, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.done, "next")
+        _maybe_add_or_extend_key(keys, self.tensor_keys.terminated, "next")
+

        # Get the parameter keys from the actor dist
        actor_dist_module = None
        for module in self.actor_network.modules():
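
Likewise, `_maybe_add_or_extend_key` is imported but not shown in this diff. An illustrative sketch consistent with how it is called in this method (again an assumption about the real helper, not its verbatim code):

    from tensordict import unravel_key


    def _maybe_add_or_extend_key(keys, key_or_keys, prefix=None):
        """Append a single (optionally prefixed) nested key, or extend with a list of them."""
        if prefix is not None:
            if isinstance(key_or_keys, (str, tuple)):
                keys.append(unravel_key((prefix, key_or_keys)))
            else:
                keys.extend(unravel_key((prefix, k)) for k in key_or_keys)
            return
        if isinstance(key_or_keys, (str, tuple)):
            keys.append(key_or_keys)
        else:
            keys.extend(key_or_keys)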