@@ -518,10 +518,7 @@ def _log_weight(
518
518
self .actor_network
519
519
) if self .functional else contextlib .nullcontext ():
520
520
dist = self .actor_network .get_dist (tensordict )
521
- if isinstance (dist , CompositeDistribution ):
522
- is_composite = True
523
- else :
524
- is_composite = False
521
+ is_composite = isinstance (dist , CompositeDistribution )
525
522
526
523
# current log_prob of actions
527
524
if is_composite :
@@ -538,6 +535,32 @@ def _log_weight(
538
535
prev_log_prob = _maybe_get_or_select (
539
536
tensordict , self .tensor_keys .sample_log_prob
540
537
)
538
+ # TODO:
539
+ # # current log_prob of actions
540
+ # action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
541
+ #
542
+ # is_composite = None
543
+ # if all(key in tensordict for key in self.actor_network.dist_params_keys):
544
+ # prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
545
+ # kwargs, is_composite = _get_composite_kwargs(prev_dist)
546
+ # if is_composite:
547
+ # prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
548
+ # else:
549
+ # prev_log_prob = prev_dist.log_prob(action, **kwargs)
550
+ # print('prev_log_prob', prev_log_prob)
551
+ # else:
552
+ # try:
553
+ # prev_log_prob = _maybe_get_or_select(
554
+ # tensordict, self.tensor_keys.sample_log_prob
555
+ # )
556
+ # except KeyError as err:
557
+ # raise _make_lp_get_error(self.tensor_keys, tensordict, err)
558
+
559
+ with self .actor_network_params .to_module (
560
+ self .actor_network
561
+ ) if self .functional else contextlib .nullcontext ():
562
+ current_dist = self .actor_network .get_dist (tensordict )
563
+
541
564
542
565
if prev_log_prob .requires_grad :
543
566
raise RuntimeError (
@@ -558,6 +581,13 @@ def _log_weight(
558
581
"the beginning of your script to get a proper composite log-prob." ,
559
582
category = UserWarning ,
560
583
)
584
+ # TODO:
585
+ # if isinstance(action, torch.Tensor):
586
+ # log_prob = current_dist.log_prob(action)
587
+ # else:
588
+ # if is_composite is None:
589
+ # kwargs, is_composite = _get_composite_kwargs(current_dist)
590
+ # log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
561
591
if (
562
592
is_composite
563
593
and not is_tensor_collection (prev_log_prob )
@@ -571,7 +601,7 @@ def _log_weight(
571
601
if is_tensor_collection (kl_approx ):
572
602
kl_approx = _sum_td_features (kl_approx )
573
603
574
- return log_weight , dist , kl_approx
604
+ return log_weight , current_dist , kl_approx
575
605
576
606
def loss_critic (self , tensordict : TensorDictBase ) -> torch .Tensor :
577
607
"""Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -647,6 +677,9 @@ def _cached_critic_network_params_detached(self):
647
677
@dispatch
648
678
def forward (self , tensordict : TensorDictBase ) -> TensorDictBase :
649
679
tensordict = tensordict .clone (False )
680
+
681
+ log_weight , dist , kl_approx = self ._log_weight (tensordict )
682
+
650
683
advantage = tensordict .get (self .tensor_keys .advantage , None )
651
684
if advantage is None :
652
685
self .value_estimator (
@@ -660,7 +693,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
660
693
scale = advantage .std ().clamp_min (1e-6 )
661
694
advantage = (advantage - loc ) / scale
662
695
663
- log_weight , dist , kl_approx = self ._log_weight (tensordict )
664
696
if is_tensor_collection (log_weight ):
665
697
log_weight = _sum_td_features (log_weight )
666
698
log_weight = log_weight .view (advantage .shape )
0 commit comments