Commit 2811962

[WIP] Compute lp during loss execution
ghstack-source-id: f16d93a
Pull Request resolved: #2688
1 parent 7575e96 commit 2811962

2 files changed: +45 −15


test/test_cost.py

+7 −9
@@ -8180,18 +8180,19 @@ def _create_seq_mock_data_ppo(
         obs = total_obs[:, :T]
         next_obs = total_obs[:, 1:]
         if atoms:
-            action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
-                -1, 1
-            )
+            action_shape = (batch, T, atoms, action_dim)
         else:
-            action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+            action_shape = (batch, T, action_dim)
+        params_mean = torch.randn(action_shape, device=device) / 10
+        params_scale = torch.rand(action_shape, device=device) / 10
+        action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
+            -1, 1
+        )
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
-        params_mean = torch.randn_like(action) / 10
-        params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
         if sample_log_prob_key is None:
@@ -8218,9 +8219,6 @@ def _create_seq_mock_data_ppo(
                 },
                 "collector": {"mask": mask},
                 action_key: action,
-                sample_log_prob_key: (
-                    torch.randn_like(action[..., 1]) / 10
-                ).masked_fill_(~mask, 0.0),
             },
             device=device,
             names=[None, "time"],
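
Note on the test change (a hedged reading, not part of the commit): the fixture now draws the action from the same Gaussian parameters (params_mean, params_scale) that later become the distribution's loc and scale, instead of caching a random sample_log_prob entry, which appears to be why the stored sample_log_prob_key value could be dropped. A minimal standalone illustration in plain PyTorch follows; the sizes and the 1e-4 floor on the scale are illustrative assumptions, not values from the commit.

    import torch
    from torch.distributions import Normal

    batch, T, action_dim = 2, 5, 4  # illustrative sizes
    loc = torch.randn(batch, T, action_dim) / 10
    scale = torch.rand(batch, T, action_dim) / 10 + 1e-4  # keep scale strictly positive
    # Action sampled from the same parameters that define the policy distribution.
    action = loc + scale * torch.randn(batch, T, action_dim)

    # The loss can rebuild the behaviour distribution from (loc, scale)
    # and evaluate the sampled log-prob on the fly.
    prev_log_prob = Normal(loc, scale).log_prob(action).sum(-1)
    print(prev_log_prob.shape)  # torch.Size([2, 5])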

torchrl/objectives/ppo.py

+38 −6
@@ -518,10 +518,7 @@ def _log_weight(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
             dist = self.actor_network.get_dist(tensordict)
-            if isinstance(dist, CompositeDistribution):
-                is_composite = True
-            else:
-                is_composite = False
+            is_composite = isinstance(dist, CompositeDistribution)

         # current log_prob of actions
         if is_composite:
@@ -538,6 +535,32 @@ def _log_weight(
         prev_log_prob = _maybe_get_or_select(
             tensordict, self.tensor_keys.sample_log_prob
         )
+        # TODO:
+        # # current log_prob of actions
+        # action = _maybe_get_or_select(tensordict, self.tensor_keys.action)
+        #
+        # is_composite = None
+        # if all(key in tensordict for key in self.actor_network.dist_params_keys):
+        #     prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
+        #     kwargs, is_composite = _get_composite_kwargs(prev_dist)
+        #     if is_composite:
+        #         prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
+        #     else:
+        #         prev_log_prob = prev_dist.log_prob(action, **kwargs)
+        #     print('prev_log_prob', prev_log_prob)
+        # else:
+        #     try:
+        #         prev_log_prob = _maybe_get_or_select(
+        #             tensordict, self.tensor_keys.sample_log_prob
+        #         )
+        #     except KeyError as err:
+        #         raise _make_lp_get_error(self.tensor_keys, tensordict, err)
+
+        with self.actor_network_params.to_module(
+            self.actor_network
+        ) if self.functional else contextlib.nullcontext():
+            current_dist = self.actor_network.get_dist(tensordict)
+

         if prev_log_prob.requires_grad:
             raise RuntimeError(
@@ -558,6 +581,13 @@ def _log_weight(
                 "the beginning of your script to get a proper composite log-prob.",
                 category=UserWarning,
             )
+        # TODO:
+        # if isinstance(action, torch.Tensor):
+        #     log_prob = current_dist.log_prob(action)
+        # else:
+        #     if is_composite is None:
+        #         kwargs, is_composite = _get_composite_kwargs(current_dist)
+        #     log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
         if (
             is_composite
             and not is_tensor_collection(prev_log_prob)
@@ -571,7 +601,7 @@ def _log_weight(
         if is_tensor_collection(kl_approx):
             kl_approx = _sum_td_features(kl_approx)

-        return log_weight, dist, kl_approx
+        return log_weight, current_dist, kl_approx

     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -647,6 +677,9 @@ def _cached_critic_network_params_detached(self):
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
+
+        log_weight, dist, kl_approx = self._log_weight(tensordict)
+
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
@@ -660,7 +693,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             scale = advantage.std().clamp_min(1e-6)
             advantage = (advantage - loc) / scale

-        log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
         log_weight = log_weight.view(advantage.shape)
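
For context, the ppo.py changes make _log_weight build the current distribution itself and let forward() call it before the advantage is computed. Below is a hedged sketch of the quantities involved, written against plain torch.distributions rather than TorchRL's API; the toy shapes and policies are assumptions, and the exact KL estimator used by the loss may differ.

    import torch
    from torch.distributions import Normal

    # Behaviour (old) and current policies evaluated on the same actions;
    # the numbers are illustrative only.
    action = torch.randn(8, 4).clamp(-1, 1)
    old_dist = Normal(torch.zeros(8, 4), torch.ones(8, 4))
    new_dist = Normal(0.1 * torch.ones(8, 4), torch.ones(8, 4))

    prev_log_prob = old_dist.log_prob(action).sum(-1).detach()  # must not require grad
    log_prob = new_dist.log_prob(action).sum(-1)

    log_weight = log_prob - prev_log_prob  # importance log-weight
    ratio = log_weight.exp()               # consumed by the clipped PPO objective
    kl_approx = prev_log_prob - log_prob   # one simple KL(old || new) estimate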
