# attention.py
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):
    """Multi-head masked self-attention with an output projection at the end."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0
        # key, query, value projections for all heads
        self.key = nn.Linear(config.n_embed, config.n_embed)
        self.query = nn.Linear(config.n_embed, config.n_embed)
        self.value = nn.Linear(config.n_embed, config.n_embed)
        # regularization
        self.attn_drop = nn.Dropout(config.attn_pdrop)
        self.resid_drop = nn.Dropout(config.resid_pdrop)
        # output projection
        self.proj = nn.Linear(config.n_embed, config.n_embed)
        # causal mask to ensure attention only looks at positions at or before the current one
        self.register_buffer(
            'mask',
            torch.tril(torch.ones(config.block_size, config.block_size))
                 .view(1, 1, config.block_size, config.block_size),
        )
        self.n_head = config.n_head

    def forward(self, x, layer_past=None):
        B, T, C = x.size()  # batch size, sequence length, embedding dimension
        # calculate query, key, values for all heads and move the head dim forward: (B, nh, T, hs)
        k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        # causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = att.masked_fill(self.mask[:, :, :T, :T] == 0, -1e10)
        att = F.softmax(att, dim=-1)
        att = self.attn_drop(att)
        # weighted sum of values: (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = att @ v
        # re-assemble all head outputs side by side: (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        # output projection
        y = self.resid_drop(self.proj(y))
        return y
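

# A minimal usage sketch, not part of the original file: `GPTConfig` is a
# hypothetical stand-in for the config object this module expects; only the
# attributes read above (n_embed, n_head, attn_pdrop, resid_pdrop, block_size)
# are assumed, and the sizes chosen here are arbitrary.
if __name__ == '__main__':
    from dataclasses import dataclass

    @dataclass
    class GPTConfig:
        n_embed: int = 128    # embedding dimension, must be divisible by n_head
        n_head: int = 4       # number of attention heads
        attn_pdrop: float = 0.1
        resid_pdrop: float = 0.1
        block_size: int = 64  # maximum sequence length the causal mask supports

    attn = CausalSelfAttention(GPTConfig())
    x = torch.randn(2, 16, 128)  # (batch, sequence length <= block_size, n_embed)
    y = attn(x)
    print(y.shape)  # torch.Size([2, 16, 128])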