@@ -511,17 +511,28 @@ def _log_weight(
511
511
# current log_prob of actions
512
512
action = _maybe_get_or_select (tensordict , self .tensor_keys .action )
513
513
514
+ is_composite = None
515
+ if all (key in tensordict for key in self .actor_network .dist_params_keys ):
516
+ prev_dist = self .actor_network .build_dist_from_params (tensordict .detach ())
517
+ kwargs , is_composite = _get_composite_kwargs (prev_dist )
518
+ if is_composite :
519
+ prev_log_prob = prev_dist .log_prob (tensordict , ** kwargs )
520
+ else :
521
+ prev_log_prob = prev_dist .log_prob (action , ** kwargs )
522
+
523
+ else :
524
+ try :
525
+ prev_log_prob = _maybe_get_or_select (
526
+ tensordict , self .tensor_keys .sample_log_prob
527
+ )
528
+ except KeyError as err :
529
+ raise _make_lp_get_error (self .tensor_keys , tensordict , err )
530
+
514
531
with self .actor_network_params .to_module (
515
532
self .actor_network
516
533
) if self .functional else contextlib .nullcontext ():
517
- dist = self .actor_network .get_dist (tensordict )
534
+ current_dist = self .actor_network .get_dist (tensordict )
518
535
519
- try :
520
- prev_log_prob = _maybe_get_or_select (
521
- tensordict , self .tensor_keys .sample_log_prob
522
- )
523
- except KeyError as err :
524
- raise _make_lp_get_error (self .tensor_keys , tensordict , err )
525
536
526
537
if prev_log_prob .requires_grad :
527
538
raise RuntimeError (
@@ -532,35 +543,12 @@ def _log_weight(
532
543
raise RuntimeError (
533
544
f"tensordict stored { self .tensor_keys .action } requires grad."
534
545
)
535
- if isinstance (dist , CompositeDistribution ):
536
- is_composite = True
537
- aggregate = dist .aggregate_probabilities
538
- if aggregate is None :
539
- aggregate = False
540
- include_sum = dist .include_sum
541
- if include_sum is None :
542
- include_sum = False
543
- kwargs = {
544
- "inplace" : False ,
545
- "aggregate_probabilities" : aggregate ,
546
- "include_sum" : include_sum ,
547
- }
548
- else :
549
- is_composite = False
550
- kwargs = {}
551
- if not is_composite :
552
- log_prob = dist .log_prob (action )
546
+ if isinstance (action , torch .Tensor ):
547
+ log_prob = current_dist .log_prob (action )
553
548
else :
554
- log_prob : TensorDictBase = dist .log_prob (tensordict , ** kwargs )
555
- if not is_tensor_collection (prev_log_prob ):
556
- # this isn't great, in general multihead actions should have a composite log-prob too
557
- warnings .warn (
558
- "You are using a composite distribution, yet your log-probability is a tensor. "
559
- "This usually happens whenever the CompositeDistribution has aggregate_probabilities=True "
560
- "or include_sum=True. These options should be avoided: leaf log-probs should be written "
561
- "independently and PPO will take care of the aggregation." ,
562
- category = UserWarning ,
563
- )
549
+ if is_composite is None :
550
+ kwargs , is_composite = _get_composite_kwargs (current_dist )
551
+ log_prob : TensorDictBase = current_dist .log_prob (tensordict , ** kwargs )
564
552
if (
565
553
is_composite
566
554
and not is_tensor_collection (prev_log_prob )
@@ -575,7 +563,7 @@ def _log_weight(
575
563
if is_tensor_collection (kl_approx ):
576
564
kl_approx = _sum_td_features (kl_approx )
577
565
578
- return log_weight , dist , kl_approx
566
+ return log_weight , current_dist , kl_approx
579
567
580
568
def loss_critic (self , tensordict : TensorDictBase ) -> torch .Tensor :
581
569
"""Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -651,6 +639,9 @@ def _cached_critic_network_params_detached(self):
651
639
@dispatch
652
640
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
653
641
tensordict = tensordict .clone (False )
642
+
643
+ log_weight , dist , kl_approx = self ._log_weight (tensordict )
644
+
654
645
advantage = tensordict .get (self .tensor_keys .advantage , None )
655
646
if advantage is None :
656
647
self .value_estimator (
@@ -664,7 +655,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
664
655
scale = advantage .std ().clamp_min (1e-6 )
665
656
advantage = (advantage - loc ) / scale
666
657
667
- log_weight , dist , kl_approx = self ._log_weight (tensordict )
668
658
if is_tensor_collection (log_weight ):
669
659
log_weight = _sum_td_features (log_weight )
670
660
log_weight = log_weight .view (advantage .shape )
@@ -1306,3 +1296,22 @@ def _make_lp_get_error(tensor_keys, log_prob, err):
1306
1296
return KeyError (result )
1307
1297
result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
1308
1298
return KeyError (result )
1299
+
1300
def _get_composite_kwargs(current_dist):
    """Build the ``log_prob`` keyword arguments for ``current_dist``.

    Returns a ``(kwargs, is_composite)`` pair. For a
    ``CompositeDistribution``, ``kwargs`` requests an out-of-place,
    leaf-wise log-prob computation; for any other distribution the
    kwargs are empty and ``is_composite`` is ``False``.
    """
    # Guard clause: non-composite distributions take no extra kwargs.
    if not isinstance(current_dist, CompositeDistribution):
        return {}, False
    aggregate = current_dist.aggregate_probabilities
    include_sum = current_dist.include_sum
    # ``None`` means "unset": fall back to False so leaf log-probs are
    # written independently rather than aggregated/summed.
    kwargs = {
        "inplace": False,
        "aggregate_probabilities": False if aggregate is None else aggregate,
        "include_sum": False if include_sum is None else include_sum,
    }
    return kwargs, True
0 commit comments