Commit cc18aac

[WIP] Compute lp during loss execution
ghstack-source-id: 8718095
Pull Request resolved: #2688

1 parent fb75d1d

2 files changed: +53 −46 lines changed

test/test_cost.py (+7 −9)
@@ -8164,18 +8164,19 @@ def _create_seq_mock_data_ppo(
         obs = total_obs[:, :T]
         next_obs = total_obs[:, 1:]
         if atoms:
-            action = torch.randn(batch, T, atoms, action_dim, device=device).clamp(
-                -1, 1
-            )
+            action_shape = (batch, T, atoms, action_dim)
         else:
-            action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1)
+            action_shape = (batch, T, action_dim)
+        params_mean = torch.randn(action_shape, device=device) / 10
+        params_scale = torch.rand(action_shape, device=device) / 10
+        action = (params_mean + params_scale * torch.randn(action_shape, device=device)).clamp(
+            -1, 1
+        )
         reward = torch.randn(batch, T, 1, device=device)
         done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         terminated = torch.zeros(batch, T, 1, dtype=torch.bool, device=device)
         mask = torch.ones(batch, T, dtype=torch.bool, device=device)
         action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
-        params_mean = torch.randn_like(action) / 10
-        params_scale = torch.rand_like(action) / 10
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
         if sample_log_prob_key is None:
@@ -8201,9 +8202,6 @@ def _create_seq_mock_data_ppo(
                 },
                 "collector": {"mask": mask},
                 action_key: action,
-                sample_log_prob_key: (
-                    torch.randn_like(action[..., 1]) / 10
-                ).masked_fill_(~mask, 0.0),
             },
             device=device,
             names=[None, "time"],
torchrl/objectives/ppo.py (+46 −37)
@@ -511,17 +511,28 @@ def _log_weight(
         # current log_prob of actions
         action = _maybe_get_or_select(tensordict, self.tensor_keys.action)

+        is_composite = None
+        if all(key in tensordict for key in self.actor_network.dist_params_keys):
+            prev_dist = self.actor_network.build_dist_from_params(tensordict.detach())
+            kwargs, is_composite = _get_composite_kwargs(prev_dist)
+            if is_composite:
+                prev_log_prob = prev_dist.log_prob(tensordict, **kwargs)
+            else:
+                prev_log_prob = prev_dist.log_prob(action, **kwargs)
+            print('prev_log_prob', prev_log_prob)
+        else:
+            try:
+                prev_log_prob = _maybe_get_or_select(
+                    tensordict, self.tensor_keys.sample_log_prob
+                )
+            except KeyError as err:
+                raise _make_lp_get_error(self.tensor_keys, tensordict, err)
+
         with self.actor_network_params.to_module(
             self.actor_network
         ) if self.functional else contextlib.nullcontext():
-            dist = self.actor_network.get_dist(tensordict)
+            current_dist = self.actor_network.get_dist(tensordict)

-        try:
-            prev_log_prob = _maybe_get_or_select(
-                tensordict, self.tensor_keys.sample_log_prob
-            )
-        except KeyError as err:
-            raise _make_lp_get_error(self.tensor_keys, tensordict, err)

         if prev_log_prob.requires_grad:
             raise RuntimeError(
@@ -532,35 +543,12 @@ def _log_weight(
             raise RuntimeError(
                 f"tensordict stored {self.tensor_keys.action} requires grad."
             )
-        if isinstance(dist, CompositeDistribution):
-            is_composite = True
-            aggregate = dist.aggregate_probabilities
-            if aggregate is None:
-                aggregate = False
-            include_sum = dist.include_sum
-            if include_sum is None:
-                include_sum = False
-            kwargs = {
-                "inplace": False,
-                "aggregate_probabilities": aggregate,
-                "include_sum": include_sum,
-            }
-        else:
-            is_composite = False
-            kwargs = {}
-        if not is_composite:
-            log_prob = dist.log_prob(action)
+        if isinstance(action, torch.Tensor):
+            log_prob = current_dist.log_prob(action)
         else:
-            log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
-            if not is_tensor_collection(prev_log_prob):
-                # this isn't great, in general multihead actions should have a composite log-prob too
-                warnings.warn(
-                    "You are using a composite distribution, yet your log-probability is a tensor. "
-                    "This usually happens whenever the CompositeDistribution has aggregate_probabilities=True "
-                    "or include_sum=True. These options should be avoided: leaf log-probs should be written "
-                    "independently and PPO will take care of the aggregation.",
-                    category=UserWarning,
-                )
+            if is_composite is None:
+                kwargs, is_composite = _get_composite_kwargs(current_dist)
+            log_prob: TensorDictBase = current_dist.log_prob(tensordict, **kwargs)
         if (
             is_composite
             and not is_tensor_collection(prev_log_prob)
@@ -575,7 +563,7 @@ def _log_weight(
         if is_tensor_collection(kl_approx):
             kl_approx = _sum_td_features(kl_approx)

-        return log_weight, dist, kl_approx
+        return log_weight, current_dist, kl_approx

     def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor:
         """Returns the critic loss multiplied by ``critic_coef``, if it is not ``None``."""
@@ -651,6 +639,9 @@ def _cached_critic_network_params_detached(self):
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         tensordict = tensordict.clone(False)
+
+        log_weight, dist, kl_approx = self._log_weight(tensordict)
+
         advantage = tensordict.get(self.tensor_keys.advantage, None)
         if advantage is None:
             self.value_estimator(
@@ -664,7 +655,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             scale = advantage.std().clamp_min(1e-6)
             advantage = (advantage - loc) / scale

-        log_weight, dist, kl_approx = self._log_weight(tensordict)
         if is_tensor_collection(log_weight):
             log_weight = _sum_td_features(log_weight)
         log_weight = log_weight.view(advantage.shape)
@@ -1306,3 +1296,22 @@ def _make_lp_get_error(tensor_keys, log_prob, err):
         return KeyError(result)
     result += "This is usually due to a missing call to loss.set_keys(sample_log_prob=<list_of_log_prob_keys>)."
     return KeyError(result)
+
+def _get_composite_kwargs(current_dist):
+    if isinstance(current_dist, CompositeDistribution):
+        is_composite = True
+        aggregate = current_dist.aggregate_probabilities
+        if aggregate is None:
+            aggregate = False
+        include_sum = current_dist.include_sum
+        if include_sum is None:
+            include_sum = False
+        kwargs = {
+            "inplace": False,
+            "aggregate_probabilities": aggregate,
+            "include_sum": include_sum,
+        }
+    else:
+        is_composite = False
+        kwargs = {}
+    return kwargs, is_composite
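
Taken together, these ppo.py changes compute the behaviour ("prev") log-probability during loss execution: when the distribution parameters written at collection time are present in the tensordict, the old distribution is rebuilt from them and its log-prob is recomputed; otherwise the loss falls back to the stored sample_log_prob entry. Below is a minimal, self-contained sketch of that control flow using plain torch.distributions and a dict standing in for the tensordict; prev_log_prob_from_batch and the key names are hypothetical illustrations, not torchrl API.

import torch
from torch.distributions import Normal

# Hypothetical helper mirroring the new branch in _log_weight with plain torch
# objects: a dict stands in for the tensordict, Normal for the actor's distribution.
def prev_log_prob_from_batch(batch: dict) -> torch.Tensor:
    if "loc" in batch and "scale" in batch:
        # Parameters stored at collection time: rebuild the behaviour
        # distribution (detached) and recompute the log-prob inside the loss.
        prev_dist = Normal(batch["loc"].detach(), batch["scale"].detach())
        return prev_dist.log_prob(batch["action"]).sum(-1)
    # Fallback: use the log-prob that was written at collection time.
    return batch["sample_log_prob"]

loc, scale = torch.zeros(8, 4), torch.ones(8, 4)
action = Normal(loc, scale).sample()
lp = prev_log_prob_from_batch({"loc": loc, "scale": scale, "action": action})
print(lp.shape)  # torch.Size([8])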
