Skip to content

Commit ba0370a

Browse files
authoredJun 15, 2025··
Optimized KV cache (#672)
* Optimized KV cache * typo fix
1 parent 2af686d commit ba0370a

File tree

2 files changed

+441
-0
lines changed

2 files changed

+441
-0
lines changed
 

‎ch04/03_kv-cache/README.md

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,3 +218,64 @@ As sequence length increases, the benefits and downsides of a KV cache become mo
218218

219219

220220

221+
 
222+
## Optimizing the KV Cache Implementation
223+
224+
While my conceptual implementation of a KV cache above helps with clarity and is mainly geared towards code readability and educational purposes, deploying it in real-world scenarios (especially with larger models and longer sequence lengths) requires more careful optimization.
225+
226+
 
227+
### Common pitfalls when scaling the cache
228+
229+
- **Memory fragmentation and repeated allocations**: Continuously concatenating tensors via `torch.cat` as shown earlier, leads to performance bottlenecks due to frequent memory allocation and reallocation.
230+
231+
- **Linear growth in memory usage**: Without proper handling, the KV cache size becomes impractical for very long sequences.
232+
233+
 
234+
#### Tip 1: Pre-allocate Memory
235+
236+
Rather than concatenating tensors repeatedly, we could pre-allocate a sufficiently large tensor based on the expected maximum sequence length. This ensures consistent memory use and reduces overhead. In pseudo-code, this may look like as follows:
237+
238+
```python
239+
# Example pre-allocation for keys and values
240+
max_seq_len = 1024 # maximum expected sequence length
241+
cache_k = torch.zeros((batch_size, num_heads, max_seq_len, head_dim), device=device)
242+
cache_v = torch.zeros((batch_size, num_heads, max_seq_len, head_dim), device=device)
243+
```
244+
245+
During inference, we can then simply write into slices of these pre-allocated tensors.
246+
247+
 
248+
#### Tip 2: Truncate Cache via Sliding Window
249+
250+
To avoid blowing up our GPU memory, we can implement a sliding window approach with dynamic truncation. Via the sliding window, we maintain only the last `window_size` tokens in the cache:
251+
252+
253+
```python
254+
# Sliding window cache implementation
255+
window_size = 512
256+
cache_k = cache_k[:, :, -window_size:, :]
257+
cache_v = cache_v[:, :, -window_size:, :]
258+
```
259+
260+
 
261+
#### Optimizations in practice
262+
263+
You can find these optimizations in the [`gpt_with_kv_cache_optimized.py`](gpt_with_kv_cache_optimized.py) file.
264+
265+
266+
On a Mac Mini with an M4 chip (CPU), with a 200-token generation and a window size of 48 below, the code runtimes compare as follows:
267+
268+
| | Tokens/sec |
269+
| -------------------------------- | ---------- |
270+
| `gpt_ch04.py` | 27 |
271+
| `gpt_with_kv_cache.py` | 110 |
272+
| `gpt_with_kv_cache_optimized.py` | 148 |
273+
274+
Unfortunately, the speed advantages disappear on CUDA devices as this is a tiny model, and the device transfer and communication outweigh the benefits of a KV cache for this small model. However, we can see a significant difference in the memory usage:
275+
276+
| | RAM |
277+
| -------------------------------- | ------- |
278+
| `gpt_ch04.py` | 0.74 GB |
279+
| `gpt_with_kv_cache.py` | 4.35 GB |
280+
| `gpt_with_kv_cache_optimized.py` | 0.89 GB |
281+
Lines changed: 380 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,380 @@
1+
# This file collects all the relevant code that we covered thus far
2+
# throughout Chapters 3-4.
3+
# This file can be run as a standalone script.
4+
5+
import time
6+
import tiktoken
7+
import torch
8+
import torch.nn as nn
9+
10+
11+
#####################################
12+
# Chapter 3
13+
#####################################
14+
class MultiHeadAttention(nn.Module):
15+
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
16+
super().__init__()
17+
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
18+
19+
self.d_out = d_out
20+
self.num_heads = num_heads
21+
self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim
22+
23+
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
24+
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
25+
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
26+
self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
27+
self.dropout = nn.Dropout(dropout)
28+
29+
####################################################
30+
# NEW
31+
self.max_seq_len = max_seq_len or context_length
32+
self.window_size = window_size or self.max_seq_len
33+
self.register_buffer("cache_k", None, persistent=False)
34+
self.register_buffer("cache_v", None, persistent=False)
35+
####################################################
36+
37+
def forward(self, x, use_cache=False):
38+
b, num_tokens, d_in = x.shape
39+
40+
keys_new = self.W_key(x) # Shape: (b, num_tokens, d_out)
41+
values_new = self.W_value(x)
42+
queries = self.W_query(x)
43+
44+
# We implicitly split the matrix by adding a `num_heads` dimension
45+
# Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
46+
keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
47+
values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
48+
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
49+
50+
# Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
51+
keys_new = keys_new.transpose(1, 2)
52+
values_new = values_new.transpose(1, 2)
53+
queries = queries.transpose(1, 2)
54+
55+
####################################################
56+
# NEW
57+
if use_cache:
58+
if self.cache_k is None or self.cache_k.size(0) != b:
59+
self.cache_k = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
60+
self.cache_v = torch.zeros(b, self.num_heads, self.max_seq_len, self.head_dim, device=x.device)
61+
self.current_pos = 0
62+
63+
# write new entries
64+
start = self.current_pos
65+
end = start + num_tokens
66+
self.cache_k[:, :, start:end, :] = keys_new
67+
self.cache_v[:, :, start:end, :] = values_new
68+
self.current_pos = end
69+
70+
# sliding window truncation
71+
if self.current_pos > self.window_size:
72+
self.cache_k = self.cache_k[:, :, -self.window_size:, :]
73+
self.cache_v = self.cache_v[:, :, -self.window_size:, :]
74+
self.current_pos = self.window_size
75+
76+
keys = self.cache_k[:, :, :self.current_pos, :]
77+
values = self.cache_v[:, :, :self.current_pos, :]
78+
else:
79+
keys = keys_new
80+
values = values_new
81+
####################################################
82+
83+
84+
# Compute scaled dot-product attention (aka self-attention) with a causal mask
85+
attn_scores = queries @ keys.transpose(2, 3) # Dot product for each head
86+
87+
####################################################
88+
# NEW
89+
K = attn_scores.size(-1)
90+
91+
if num_tokens == K:
92+
# No cache → use the pre‑baked triangular mask slice
93+
causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
94+
else:
95+
# Cached: need to offset the diagonal by (K − num_tokens)
96+
offset = K - num_tokens # number of tokens already in cache before this chunk
97+
row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1) # (num_tokens, 1)
98+
col_idx = torch.arange(K, device=x.device).unsqueeze(0) # (1, K)
99+
causal_mask = row_idx + offset < col_idx # True where j > i+offset
100+
####################################################
101+
102+
# Use the mask to fill attention scores
103+
attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
104+
105+
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
106+
attn_weights = self.dropout(attn_weights)
107+
108+
# Shape: (b, num_tokens, num_heads, head_dim)
109+
context_vec = (attn_weights @ values).transpose(1, 2)
110+
111+
# Combine heads, where self.d_out = self.num_heads * self.head_dim
112+
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
113+
context_vec = self.out_proj(context_vec) # optional projection
114+
115+
return context_vec
116+
117+
####################################################
118+
# NEW
119+
def reset_cache(self):
120+
self.cache_k, self.cache_v = None, None
121+
####################################################
122+
123+
124+
#####################################
125+
# Chapter 4
126+
#####################################
127+
class LayerNorm(nn.Module):
128+
def __init__(self, emb_dim):
129+
super().__init__()
130+
self.eps = 1e-5
131+
self.scale = nn.Parameter(torch.ones(emb_dim))
132+
self.shift = nn.Parameter(torch.zeros(emb_dim))
133+
134+
def forward(self, x):
135+
mean = x.mean(dim=-1, keepdim=True)
136+
var = x.var(dim=-1, keepdim=True, unbiased=False)
137+
norm_x = (x - mean) / torch.sqrt(var + self.eps)
138+
return self.scale * norm_x + self.shift
139+
140+
141+
class GELU(nn.Module):
142+
def __init__(self):
143+
super().__init__()
144+
145+
def forward(self, x):
146+
return 0.5 * x * (1 + torch.tanh(
147+
torch.sqrt(torch.tensor(2.0 / torch.pi)) *
148+
(x + 0.044715 * torch.pow(x, 3))
149+
))
150+
151+
152+
class FeedForward(nn.Module):
153+
def __init__(self, cfg):
154+
super().__init__()
155+
self.layers = nn.Sequential(
156+
nn.Linear(cfg["emb_dim"], 4 * cfg["emb_dim"]),
157+
GELU(),
158+
nn.Linear(4 * cfg["emb_dim"], cfg["emb_dim"]),
159+
)
160+
161+
def forward(self, x):
162+
return self.layers(x)
163+
164+
165+
class TransformerBlock(nn.Module):
166+
def __init__(self, cfg):
167+
super().__init__()
168+
self.att = MultiHeadAttention(
169+
d_in=cfg["emb_dim"],
170+
d_out=cfg["emb_dim"],
171+
context_length=cfg["context_length"],
172+
num_heads=cfg["n_heads"],
173+
dropout=cfg["drop_rate"],
174+
qkv_bias=cfg["qkv_bias"],
175+
window_size=cfg["kv_window_size"]) # NEW
176+
self.ff = FeedForward(cfg)
177+
self.norm1 = LayerNorm(cfg["emb_dim"])
178+
self.norm2 = LayerNorm(cfg["emb_dim"])
179+
self.drop_shortcut = nn.Dropout(cfg["drop_rate"])
180+
181+
def forward(self, x, use_cache=False):
182+
# Shortcut connection for attention block
183+
shortcut = x
184+
x = self.norm1(x)
185+
186+
# x = self.att(x) # Shape [batch_size, num_tokens, emb_size]
187+
####################################################
188+
# NEW
189+
x = self.att(x, use_cache=use_cache)
190+
####################################################
191+
192+
x = self.drop_shortcut(x)
193+
x = x + shortcut # Add the original input back
194+
195+
# Shortcut connection for feed-forward block
196+
shortcut = x
197+
x = self.norm2(x)
198+
x = self.ff(x)
199+
x = self.drop_shortcut(x)
200+
x = x + shortcut # Add the original input back
201+
202+
return x
203+
204+
205+
class GPTModel(nn.Module):
206+
def __init__(self, cfg):
207+
super().__init__()
208+
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"])
209+
self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
210+
self.drop_emb = nn.Dropout(cfg["drop_rate"])
211+
212+
# self.trf_blocks = nn.Sequential(
213+
# *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
214+
####################################################
215+
# NEW
216+
self.trf_blocks = nn.ModuleList(
217+
[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
218+
219+
self.current_pos = 0
220+
####################################################
221+
222+
self.final_norm = LayerNorm(cfg["emb_dim"])
223+
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
224+
225+
def forward(self, in_idx, use_cache=False):
226+
batch_size, seq_len = in_idx.shape
227+
tok_embeds = self.tok_emb(in_idx)
228+
229+
# pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
230+
231+
####################################################
232+
# NEW
233+
234+
if use_cache:
235+
pos_ids = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device, dtype=torch.long)
236+
self.current_pos += seq_len
237+
else:
238+
pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
239+
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
240+
####################################################
241+
242+
x = tok_embeds + pos_embeds # Shape [batch_size, num_tokens, emb_size]
243+
x = self.drop_emb(x)
244+
245+
# x = self.trf_blocks(x)
246+
####################################################
247+
# NEW
248+
for blk in self.trf_blocks:
249+
x = blk(x, use_cache=use_cache)
250+
####################################################
251+
252+
x = self.final_norm(x)
253+
logits = self.out_head(x)
254+
return logits
255+
256+
####################################################
257+
# NEW
258+
def reset_kv_cache(self):
259+
for blk in self.trf_blocks:
260+
blk.att.reset_cache()
261+
262+
####################################################
263+
264+
265+
def generate_text_simple(model, idx, max_new_tokens, context_size):
266+
# idx is (B, T) array of indices in the current context
267+
for _ in range(max_new_tokens):
268+
269+
# Crop current context if it exceeds the supported context size
270+
# E.g., if LLM supports only 5 tokens, and the context size is 10
271+
# then only the last 5 tokens are used as context
272+
idx_cond = idx[:, -context_size:]
273+
274+
# Get the predictions
275+
with torch.no_grad():
276+
logits = model(idx_cond)
277+
278+
# Focus only on the last time step
279+
# (batch, n_token, vocab_size) becomes (batch, vocab_size)
280+
logits = logits[:, -1, :]
281+
282+
# Get the idx of the vocab entry with the highest logits value
283+
idx_next = torch.argmax(logits, dim=-1, keepdim=True) # (batch, 1)
284+
285+
# Append sampled index to the running sequence
286+
idx = torch.cat((idx, idx_next), dim=1) # (batch, n_tokens+1)
287+
288+
return idx
289+
290+
291+
####################################################
292+
# NEW
293+
def generate_text_simple_cached(model, idx, max_new_tokens):
294+
model.eval()
295+
model.reset_kv_cache()
296+
297+
# Init cache with full prompt
298+
logits = model(idx, use_cache=True)
299+
300+
for _ in range(max_new_tokens):
301+
last_logits = logits[:, -1]
302+
next_idx = last_logits.argmax(dim=-1, keepdim=True)
303+
idx = torch.cat([idx, next_idx], dim=1)
304+
305+
logits = model(next_idx, use_cache=True)
306+
307+
return idx
308+
####################################################
309+
310+
311+
def main():
312+
GPT_CONFIG_124M = {
313+
"vocab_size": 50257, # Vocabulary size
314+
"context_length": 1024, # Context length
315+
"emb_dim": 768, # Embedding dimension
316+
"n_heads": 12, # Number of attention heads
317+
"n_layers": 12, # Number of layers
318+
"drop_rate": 0.1, # Dropout rate
319+
"qkv_bias": False, # Query-Key-Value bias
320+
"kv_window_size": 48 # NEW: KV cache window size
321+
}
322+
323+
torch.manual_seed(123)
324+
model = GPTModel(GPT_CONFIG_124M)
325+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
326+
model.to(device)
327+
model.eval() # disable dropout
328+
329+
start_context = "Hello, I am"
330+
331+
tokenizer = tiktoken.get_encoding("gpt2")
332+
encoded = tokenizer.encode(start_context)
333+
encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
334+
335+
print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
336+
print("\nInput text:", start_context)
337+
print("Encoded input text:", encoded)
338+
print("encoded_tensor.shape:", encoded_tensor.shape)
339+
340+
if torch.cuda.is_available():
341+
torch.cuda.synchronize()
342+
start = time.time()
343+
344+
# token_ids = generate_text_simple(
345+
# model=model,
346+
# idx=encoded_tensor,
347+
# max_new_tokens=200,
348+
# context_size=GPT_CONFIG_124M["context_length"]
349+
# )
350+
351+
####################################################
352+
# NEW
353+
token_ids = generate_text_simple_cached(
354+
model=model,
355+
idx=encoded_tensor,
356+
max_new_tokens=200,
357+
)
358+
####################################################
359+
360+
if torch.cuda.is_available():
361+
torch.cuda.synchronize()
362+
total_time = time.time() - start
363+
364+
decoded_text = tokenizer.decode(token_ids.squeeze(0).tolist())
365+
366+
print(f"\n\n{50*'='}\n{22*' '}OUT\n{50*'='}")
367+
print("\nOutput:", token_ids)
368+
print("Output length:", len(token_ids[0]))
369+
print("Output text:", decoded_text)
370+
371+
print(f"\nTime: {total_time:.2f} sec")
372+
print(f"{int(len(token_ids[0])/total_time)} tokens/sec")
373+
if torch.cuda.is_available():
374+
max_mem_bytes = torch.cuda.max_memory_allocated()
375+
max_mem_gb = max_mem_bytes / (1024 ** 3)
376+
print(f"Max memory allocated: {max_mem_gb:.2f} GB")
377+
378+
379+
if __name__ == "__main__":
380+
main()

0 commit comments

Comments
 (0)
Please sign in to comment.