sample.py
import torch
import torch.nn.functional as F

from utils import limit_past, kl, entropy


def sample(model, enc, length, context, temperature=1.0, device='cuda', topk=-1):
    """Sample `length` tokens from the model given `context`, with temperature
    scaling and optional top-k filtering. Returns the generated token ids plus
    the average NLL of the chosen tokens, the average KL from the base
    distribution (in bits), and the average entropy of the sampling distribution.
    """
    assert length > 0

    # Truncate the context so it fits inside the model's attention window.
    context = torch.tensor(context[-1022:], device=device, dtype=torch.long)

    prev = context
    output = context
    past = None

    total_log_probs = 0
    total_entropy_ptau = 0
    total_num = 0
    total_kl = 0  # in bits

    with torch.no_grad():
        while total_num < length:
            if past and past[0].shape[3] >= 1023:
                raise RuntimeError("cached attention states exceed the model's context window")

            logits, past = model(prev.unsqueeze(0), past=past)
            past = limit_past(past)
            logits[0, -1, -1] = -1e10   # endoftext can't happen
            logits[0, -1, 628] = -1e10  # 2 newlines can't happen
            logits, indices = logits[0, -1, :].sort(descending=True)

            base_log_probs = F.log_softmax(logits, dim=-1)

            # Apply top-k filtering and temperature to get the sampling distribution.
            if topk > 0:
                logits = logits[:topk]
            logits = logits / temperature
            log_probs = F.log_softmax(logits, dim=-1)
            probs = torch.exp(log_probs)

            # KL between the rescaled (and possibly truncated) sampling distribution
            # and the base distribution; slicing by len(probs) keeps the supports
            # aligned whether or not top-k filtering was applied.
            total_kl += kl(probs, log_probs, base_log_probs[:len(probs)])

            selection = torch.multinomial(probs, num_samples=1).item()
            log_prob_chosen = base_log_probs[selection]
            total_log_probs += log_prob_chosen.item()

            total_entropy_ptau += entropy(probs, log_probs)

            prev = indices[selection].view(1)
            output = torch.cat((output, prev))
            total_num += 1

    avg_NLL = -total_log_probs / total_num
    avg_KL = total_kl / total_num
    avg_Hq = total_entropy_ptau / total_num

    # Return only the newly generated tokens; the context prefix is stripped off.
    return output[len(context):].tolist(), avg_NLL, avg_KL, avg_Hq
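

# --- Usage sketch (illustrative; not part of the original module) -------------
# The loader below is a hypothetical stand-in for however the surrounding
# project constructs the GPT-2 tokenizer (`enc`) and language model; the
# prompt and hyperparameter values are likewise assumptions, not values taken
# from the source.
if __name__ == '__main__':
    from utils import get_model  # hypothetical helper returning (enc, model)

    enc, model = get_model()
    context_tokens = enc.encode('The meaning of life is')

    tokens, avg_nll, avg_kl, avg_hq = sample(
        model, enc, length=40, context=context_tokens,
        temperature=0.8, device='cuda', topk=100)

    print(enc.decode(tokens))
    print('avg NLL: %.3f  avg KL (bits): %.3f  avg H_q: %.3f'
          % (avg_nll, float(avg_kl), float(avg_hq)))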