 # - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch

+from .utils import KVCache  # noqa: F401
+
 import torch
 import torch.nn as nn

 # Chapter 3
 #####################################
 class MultiHeadAttention(nn.Module):
-    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
+    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
         super().__init__()
         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@@ -25,80 +27,41 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
         self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)

-        ####################################################
-        # NEW
-        self.max_seq_len = max_seq_len or context_length
-        self.window_size = window_size or self.max_seq_len
-        self.register_buffer("cache_k", None, persistent=False)
-        self.register_buffer("cache_v", None, persistent=False)
-        ####################################################
-
-    def forward(self, x, use_cache=False):
+    def forward(self, x, use_cache=False, start_pos=0, cache=None):
         b, num_tokens, d_in = x.shape

-        keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
-        values_new = self.W_value(x)
+        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+        values = self.W_value(x)
         queries = self.W_query(x)

         # We implicitly split the matrix by adding a `num_heads` dimension
         # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
-        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
-        values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
+        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

         # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
-        keys_new = keys_new.transpose(1, 2)
-        values_new = values_new.transpose(1, 2)
+        keys = keys.transpose(1, 2)
         queries = queries.transpose(1, 2)
+        values = values.transpose(1, 2)

-        ####################################################
-        # NEW
         if use_cache:
-            if self.cache_k is None or self.cache_k.size(0) != b:
-                self.cache_k = torch.zeros(b, self.num_heads,
-                                           self.window_size, self.head_dim,
-                                           device=x.device)
-                self.cache_v = torch.zeros_like(self.cache_k)
-                self.ptr_cur = 0  # pointer to next free slot
-
-            # if incoming chunk would overflow discard oldest tokens
-            if self.ptr_cur + num_tokens > self.window_size:
-                overflow = self.ptr_cur + num_tokens - self.window_size
-                # shift everything left by `overflow` (cheap view-copy)
-                self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
-                self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
-                self.ptr_cur -= overflow  # pointer after shift
-
-            self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
-            self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
-            self.ptr_cur += num_tokens
-
-            keys = self.cache_k[:, :, :self.ptr_cur, :]
-            values = self.cache_v[:, :, :self.ptr_cur, :]
+            if cache is not None:
+                keys = torch.cat([cache[0], keys], dim=2)
+                values = torch.cat([cache[1], values], dim=2)
+            next_cache = (keys, values)
         else:
-            keys, values = keys_new, values_new
-            self.ptr_cur = 0  # keep pointer sane if you interleave modes
-        ####################################################
-        # Compute scaled dot-product attention (aka self-attention) with a causal mask
-        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+            next_cache = None

-        ####################################################
-        # NEW
-        K = attn_scores.size(-1)
+        seq_len = keys.size(2)
+        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
+        causal_mask = causal_mask[-num_tokens:, :][None, None, :, :]

-        if num_tokens == K:
-            # No cache → use the pre-baked triangular mask slice
-            causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
-        else:
-            # Cached: need to offset the diagonal by (K - num_tokens)
-            offset = K - num_tokens  # number of tokens already in cache before this chunk
-            row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)  # (num_tokens, 1)
-            col_idx = torch.arange(K, device=x.device).unsqueeze(0)           # (1, K)
-            causal_mask = row_idx + offset < col_idx  # True where j > i+offset
-        ####################################################
+        # Compute scaled dot-product attention (aka self-attention) with a causal mask
+        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

         # Use the mask to fill attention scores
-        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
+        attn_scores.masked_fill_(causal_mask, -torch.inf)

         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
         attn_weights = self.dropout(attn_weights)
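A quick way to see what the new mask lines do (an illustrative snippet, not part of the diff): with four cached tokens and one incoming token, the sliced mask has shape (1, 1, 1, 5) and lets the new query attend to every position up to and including its own.

import torch

# Illustration only: 4 cached tokens plus 1 new token, as in single-step decoding
num_tokens, seq_len = 1, 5
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
causal_mask = causal_mask[-num_tokens:, :][None, None, :, :]
print(causal_mask.shape)  # torch.Size([1, 1, 1, 5])
print(causal_mask[0, 0])  # tensor([[False, False, False, False, False]])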
@@ -110,13 +73,7 @@ def forward(self, x, use_cache=False):
         context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
         context_vec = self.out_proj(context_vec)  # optional projection

-        return context_vec
-
-    ####################################################
-    # NEW
-    def reset_cache(self):
-        self.cache_k, self.cache_v = None, None
-    ####################################################
+        return context_vec, next_cache


 #####################################
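For orientation (not part of the diff): with this refactor the attention layer no longer owns its cache; the caller holds a (keys, values) tuple per layer and passes it back in on the next step. A minimal two-step sketch, with hyperparameter values chosen only for illustration:

import torch

torch.manual_seed(123)
mha = MultiHeadAttention(d_in=32, d_out=32, context_length=16, dropout=0.0, num_heads=4)
mha.eval()

prompt = torch.randn(1, 5, 32)             # 5 prompt tokens
out, cache = mha(prompt, use_cache=True)   # cache is a (keys, values) tuple
print(cache[0].shape)                      # torch.Size([1, 4, 5, 8])

next_tok = torch.randn(1, 1, 32)           # one new token per decoding step
out, cache = mha(next_tok, use_cache=True, cache=cache)
print(cache[0].shape)                      # torch.Size([1, 4, 6, 8])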
@@ -169,25 +126,17 @@ def __init__(self, cfg):
             context_length=cfg["context_length"],
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
-            qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]  # NEW
-        )
+            qkv_bias=cfg["qkv_bias"])
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
         self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

-    def forward(self, x, use_cache=False):
+    def forward(self, x, use_cache=False, start_pos=0, cache=None):
         # Shortcut connection for attention block
         shortcut = x
         x = self.norm1(x)
-
-        # x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
-        ####################################################
-        # NEW
-        x = self.att(x, use_cache=use_cache)
-        ####################################################
-
+        x, next_cache = self.att(x, use_cache=use_cache, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
         x = self.drop_shortcut(x)
         x = x + shortcut  # Add the original input back
@@ -198,7 +147,7 @@ def forward(self, x, use_cache=False):
         x = self.drop_shortcut(x)
         x = x + shortcut  # Add the original input back

-        return x
+        return x, next_cache


 class GPTModel(nn.Module):
@@ -208,80 +157,34 @@ def __init__(self, cfg):
         self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
         self.drop_emb = nn.Dropout(cfg["drop_rate"])

-        # self.trf_blocks = nn.Sequential(
-        #    *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
-        ####################################################
-        # NEW
-        self.trf_blocks = nn.ModuleList(
-            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
-
-        self.ptr_current_pos = 0
-        ####################################################
+        self.trf_blocks = nn.Sequential(
+            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

         self.final_norm = LayerNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+        self.current_pos = 0

-    def forward(self, in_idx, use_cache=False):
+    def forward(self, in_idx, use_cache=False, cache=None):
         batch_size, seq_len = in_idx.shape
+        pos = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device)
         tok_embeds = self.tok_emb(in_idx)
-
-        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
-
-        ####################################################
-        # NEW
+        pos_embeds = self.pos_emb(pos)
+        x = self.drop_emb(tok_embeds + pos_embeds)

         if use_cache:
-            pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
-            self.ptr_current_pos += seq_len
+            start_pos = self.current_pos
+            self.current_pos += seq_len
         else:
-            pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
-        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
-        ####################################################
-
-        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
-        x = self.drop_emb(x)
+            start_pos = 0

-        # x = self.trf_blocks(x)
-        ####################################################
-        # NEW
-        for blk in self.trf_blocks:
-            x = blk(x, use_cache=use_cache)
-        ####################################################
+        next_cache = []
+        for i, block in enumerate(self.trf_blocks):
+            blk_cache = cache.get(i) if cache else None
+            x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
+            if cache:
+                cache.update(i, new_cache)
+            next_cache.append(new_cache)

         x = self.final_norm(x)
         logits = self.out_head(x)
         return logits
-
-    ####################################################
-    # NEW
-    def reset_kv_cache(self):
-        for blk in self.trf_blocks:
-            blk.att.reset_cache()
-        self.ptr_current_pos = 0
-    ####################################################
-
-
-def generate_text_simple(model, idx, max_new_tokens, context_size):
-    # idx is (B, T) array of indices in the current context
-    for _ in range(max_new_tokens):
-
-        # Crop current context if it exceeds the supported context size
-        # E.g., if LLM supports only 5 tokens, and the context size is 10
-        # then only the last 5 tokens are used as context
-        idx_cond = idx[:, -context_size:]
-
-        # Get the predictions
-        with torch.no_grad():
-            logits = model(idx_cond)
-
-        # Focus only on the last time step
-        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
-        logits = logits[:, -1, :]
-
-        # Get the idx of the vocab entry with the highest logits value
-        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
-
-        # Append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
-
-    return idx
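The KVCache class imported at the top lives in .utils and is not shown in this diff. Judging from the cache.get(i) and cache.update(i, ...) calls in GPTModel.forward, a minimal compatible container would look roughly like the sketch below; the decoding loop and config values are likewise illustrative assumptions, not code from the repository.

import torch


class KVCache:
    # Sketch of a per-layer (keys, values) store matching the get/update calls above;
    # the real class in .utils may differ.
    def __init__(self, n_layers):
        self.cache = [None] * n_layers

    def get(self, layer_idx):
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        self.cache[layer_idx] = value


# Assumed GPT-2 124M-style config values, as used elsewhere in the book
cfg = {"vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
       "n_heads": 12, "n_layers": 12, "drop_rate": 0.0, "qkv_bias": False}
model = GPTModel(cfg)
model.eval()

idx = torch.tensor([[1, 2, 3]])            # (batch, num_tokens) token IDs
cache = KVCache(n_layers=cfg["n_layers"])

with torch.no_grad():
    logits = model(idx, use_cache=True, cache=cache)  # prime the cache with the prompt
    for _ in range(5):
        next_idx = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_idx], dim=1)
        logits = model(next_idx, use_cache=True, cache=cache)  # feed only the new token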