@@ -53,11 +53,11 @@ def log_probs_from_scores(td: TensorDictBase) -> TensorDictBase:
53
53
- "tokens_out", "scores"
54
54
55
55
"""
56
- # TODO: how do we avoid getting these?
57
56
tokens_out = td ["tokens_out" , "sequences" ]
58
57
seq_len = tokens_out .shape [1 ]
59
58
60
59
del td ["tokens_out" , "past_key_values" ]
60
+
61
61
scores = dict (td ["tokens_out" , "scores" ].items ())
62
62
scores = torch .stack (
63
63
[scores [str (k )] for k in range (len (scores ))], 1
@@ -90,15 +90,18 @@ def log_probs_from_logits(td: TensorDictBase) -> TensorDictBase:
90
90
- "forward", "past_key_values"
91
91
- "forward"
92
92
"""
93
- # TODO: how do we avoid getting these?
93
+ tokens_out = td ["tokens_response" , "input_ids" ]
94
+ seq_len = tokens_out .shape [- 1 ]
95
+
94
96
del td ["forward" , "past_key_values" ]
97
+
95
98
scores = td ["forward" , "logits" ]
99
+ scores = scores [..., - seq_len :, :]
96
100
logits = scores - scores .logsumexp (dim = - 1 , keepdim = True )
97
101
td ["logits" ] = scores
98
102
del td ["forward" ]
99
103
scores .shape [1 ]
100
- tokens = td ["tokens_in" , "input_ids" ]
101
- log_probs = logits .gather (- 1 , tokens .unsqueeze (- 1 ))
104
+ log_probs = logits .gather (- 1 , tokens_out .unsqueeze (- 1 ))
102
105
td ["log_probs" ] = log_probs
103
106
return td
104
107
0 commit comments