Skip to content

Commit 95d4b04

Browse files
committed Mar 17, 2025
[BugFix] Right log-prob size in transformer wrapper
ghstack-source-id: 5226bb4d25bbaaf139b24cf96d096f1d732013d3 Pull Request resolved: pytorch/rl#2854
1 parent 6b0e6c3 commit 95d4b04

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed
 

‎torchrl/data/tensor_specs.py

+1-1
Original file line number | Diff line number | Diff line change
@@ -4941,7 +4941,7 @@ def set(self, name: str, spec: TensorSpec) -> Composite:
49414941
spec.shape = self.shape
49424942
else:
49434943
raise ValueError(
4944-
f"The shape of the spec {type(spec).__name__} and the Composite {type(self).__name__} mismatch: the first "
4944+
f"The shapes of the spec {type(spec).__name__} and the {type(self).__name__} mismatch: the first "
49454945
f"{self.ndim} dimensions should match but got spec.shape={spec.shape} and "
49464946
f"Composite.shape={self.shape}."
49474947
)

‎torchrl/modules/llm/transformers_policy.py

+7-4
Original file line number | Diff line number | Diff line change
@@ -53,11 +53,11 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
5353
- "tokens_out", "scores"
5454
5555
"""
56-
# TODO: how do we avoid getting these?
5756
tokens_out = td["tokens_out", "sequences"]
5857
seq_len = tokens_out.shape[1]
5958

6059
del td["tokens_out", "past_key_values"]
60+
6161
scores = dict(td["tokens_out", "scores"].items())
6262
scores = torch.stack(
6363
[scores[str(k)] for k in range(len(scores))], 1
@@ -90,15 +90,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
9090
- "forward", "past_key_values"
9191
- "forward"
9292
"""
93-
# TODO: how do we avoid getting these?
93+
tokens_out = td["tokens_response", "input_ids"]
94+
seq_len = tokens_out.shape[-1]
95+
9496
del td["forward", "past_key_values"]
97+
9598
scores = td["forward", "logits"]
99+
scores = scores[..., -seq_len:, :]
96100
logits = scores - scores.logsumexp(dim=-1, keepdim=True)
97101
td["logits"] = scores
98102
del td["forward"]
99103
scores.shape[1]
100-
tokens = td["tokens_in", "input_ids"]
101-
log_probs = logits.gather(-1, tokens.unsqueeze(-1))
104+
log_probs = logits.gather(-1, tokens_out.unsqueeze(-1))
102105
td["log_probs"] = log_probs
103106
return td
104107

0 commit comments

Comments
 (0)
Please sign in to comment.