
Commit 81eda38

Improve KV cache code for torch.compile (#705)
* Improve KV cache code for torch.compile
* cleanup
* cleanup
1 parent 6522be9 · commit 81eda38

File tree

8 files changed, +595 −317 lines

8 files changed

+595
-317
lines changed

ch04/03_kv-cache/gpt_ch04.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            torch.triu(torch.ones(context_length, context_length), diagonal=1),
             persistent=False
         )

ch05/07_gpt_to_llama/README.md

Lines changed: 6 additions & 6 deletions
@@ -236,14 +236,14 @@ token_ids = generate_text_simple(
 )
 ```

-Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements).
+Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements; and KV-cache memory may increase prohibitively for longer context lengths).

 | Model       | Mode              | Hardware        | Tokens/sec | GPU Memory (VRAM) |
-|-------------|-------------------|-----------------|------------|-------------------|
+| ----------- | ----------------- | --------------- | ---------- | ----------------- |
 | Llama3Model | Regular           | Mac Mini M4 CPU | 1          | -                 |
 | Llama3Model | Regular compiled  | Mac Mini M4 CPU | -          | -                 |
-| Llama3Model | KV cache          | Mac Mini M4 CPU | 62         | -                 |
-| Llama3Model | KV cache compiled | Mac Mini M4 CPU | -          | -                 |
+| Llama3Model | KV cache          | Mac Mini M4 CPU | 68         | -                 |
+| Llama3Model | KV cache compiled | Mac Mini M4 CPU | 86         | -                 |
 |             |                   |                 |            |                   |
 | Llama3Model | Regular           | Mac Mini M4 GPU | 15         | -                 |
 | Llama3Model | Regular compiled  | Mac Mini M4 GPU | -          | -                 |
@@ -252,7 +252,7 @@ Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is
 |             |                   |                 |            |                   |
 | Llama3Model | Regular           | Nvidia A100 GPU | 42         | 2.91 GB           |
 | Llama3Model | Regular compiled  | Nvidia A100 GPU | 170        | 3.12 GB           |
-| Llama3Model | KV cache          | Nvidia A100 GPU | 60         | 18.87 GB          |
-| Llama3Model | KV cache compiled | Nvidia A100 GPU | 59         | 19.12 GB          |
+| Llama3Model | KV cache          | Nvidia A100 GPU | 58         | 2.87 GB           |
+| Llama3Model | KV cache compiled | Nvidia A100 GPU | 161        | 3.61 GB           |

 Note that all settings above have been tested to produce the same text outputs.
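
As a rough intuition for the updated README remark about KV-cache memory and context length: each layer stores one key and one value vector per token and per KV head, so the cache grows linearly with the generated sequence. The snippet below is only an illustration with made-up dimensions, not the actual Llama 3 configuration used in this table:

```python
# Back-of-the-envelope KV-cache size estimate with illustrative (made-up) numbers;
# the real Llama 3 configuration differs.
n_layers = 16
n_kv_heads = 8
head_dim = 64
bytes_per_value = 2  # e.g., bfloat16


def kv_cache_bytes(seq_len):
    # 2x for keys and values, per layer, per KV head, per cached token
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * bytes_per_value


print(f"{kv_cache_bytes(150) / 1e6:.1f} MB for 150 tokens")
print(f"{kv_cache_bytes(131_072) / 1e9:.1f} GB for a 131k-token context")
```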

ch05/11_qwen3/README.md

Lines changed: 6 additions & 6 deletions
@@ -209,23 +209,23 @@ token_ids = generate_text_simple(
 )
 ```

-Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage dominates here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements).
+Note that the peak memory usage is only listed for Nvidia CUDA devices, as it is easier to calculate. However, the memory usage on other devices is likely similar as it uses a similar precision format, and the KV cache storage results in even lower memory usage here for the generated 150-token text (however, different devices may implement matrix multiplication differently and may result in different peak memory requirements; and KV-cache memory may increase prohibitively for longer context lengths).

 | Model      | Mode              | Hardware        | Tokens/sec | GPU Memory (VRAM) |
 | ---------- | ----------------- | --------------- | ---------- | ----------------- |
 | Qwen3Model | Regular           | Mac Mini M4 CPU | 1          | -                 |
 | Qwen3Model | Regular compiled  | Mac Mini M4 CPU | 1          | -                 |
 | Qwen3Model | KV cache          | Mac Mini M4 CPU | 80         | -                 |
-| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 82         | -                 |
+| Qwen3Model | KV cache compiled | Mac Mini M4 CPU | 137        | -                 |
 |            |                   |                 |            |                   |
 | Qwen3Model | Regular           | Mac Mini M4 GPU | 21         | -                 |
 | Qwen3Model | Regular compiled  | Mac Mini M4 GPU | Error      | -                 |
-| Qwen3Model | KV cache          | Mac Mini M4 GPU | 32         | -                 |
+| Qwen3Model | KV cache          | Mac Mini M4 GPU | 28         | -                 |
 | Qwen3Model | KV cache compiled | Mac Mini M4 GPU | Error      | -                 |
 |            |                   |                 |            |                   |
-| Qwen3Model | Regular           | Nvidia A100 GPU | 25         | 1.49 GB           |
+| Qwen3Model | Regular           | Nvidia A100 GPU | 26         | 1.49 GB           |
 | Qwen3Model | Regular compiled  | Nvidia A100 GPU | 107        | 1.99 GB           |
-| Qwen3Model | KV cache          | Nvidia A100 GPU | 25         | 10.20 GB          |
-| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 24         | 10.61 GB          |
+| Qwen3Model | KV cache          | Nvidia A100 GPU | 25         | 1.47 GB           |
+| Qwen3Model | KV cache compiled | Nvidia A100 GPU | 90         | 1.48 GB           |

 Note that all settings above have been tested to produce the same text outputs.
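
The "compiled" rows in both benchmark tables refer to wrapping the model with `torch.compile` before generation. The toy snippet below only sketches the general pattern (compile once, exclude the first warm-up call from timing); it is not the repository's benchmark script, and the stand-in module is made up:

```python
import time

import torch
import torch.nn as nn

# Toy stand-in model; the real benchmarks compile Qwen3Model / Llama3Model.
toy = nn.Sequential(nn.Linear(64, 256), nn.GELU(), nn.Linear(256, 64))
toy_compiled = torch.compile(toy)

x = torch.randn(1, 64)
_ = toy_compiled(x)  # warm-up: the first call triggers compilation

start = time.time()
for _ in range(100):
    _ = toy_compiled(x)
print(f"{100 / (time.time() - start):.1f} iterations/sec")
```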

pkg/llms_from_scratch/kv_cache/generate.py

Lines changed: 4 additions & 3 deletions
@@ -3,23 +3,24 @@
 # - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch

+from .utils import KVCache
 import torch


 def generate_text_simple(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
-
     ctx_len = context_size or model.cfg["context_length"]
+    cache = KVCache(n_layers=model.cfg["n_layers"]) if use_cache else None

     with torch.no_grad():
         if use_cache:
             model.reset_kv_cache()
-            logits = model(idx[:, -ctx_len:], use_cache=True)
+            logits = model(idx[:, -ctx_len:], use_cache=True, cache=cache)

             for _ in range(max_new_tokens):
                 next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
                 idx = torch.cat([idx, next_idx], dim=1)
-                logits = model(next_idx, use_cache=True)
+                logits = model(next_idx, use_cache=True, cache=cache)
         else:
             for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
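
The `KVCache` container itself lives in `pkg/llms_from_scratch/kv_cache/utils.py`, which is not shown in this diff view. Based purely on how it is used in this commit (constructed with `n_layers`, read with `get(i)`, and written with `update(i, new_cache)`), a minimal sketch might look as follows; the actual implementation in the repository may differ:

```python
# Hypothetical sketch of the KVCache container used above; the real class is
# defined in pkg/llms_from_scratch/kv_cache/utils.py and may differ in detail.
class KVCache:
    def __init__(self, n_layers):
        # One (keys, values) slot per transformer block, initially empty
        self.cache = [None] * n_layers

    def get(self, layer_idx):
        # Returns the cached (keys, values) tuple for this layer, or None
        return self.cache[layer_idx]

    def update(self, layer_idx, value):
        # Stores the latest (keys, values) tuple produced by this layer
        self.cache[layer_idx] = value
```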

pkg/llms_from_scratch/kv_cache/gpt2.py

Lines changed: 43 additions & 140 deletions
@@ -3,6 +3,8 @@
 # - https://www.manning.com/books/build-a-large-language-model-from-scratch
 # Code: https://github.com/rasbt/LLMs-from-scratch

+from .utils import KVCache  # noqa: F401
+
 import torch
 import torch.nn as nn

@@ -11,7 +13,7 @@
 # Chapter 3
 #####################################
 class MultiHeadAttention(nn.Module):
-    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False, max_seq_len=None, window_size=None):
+    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
         super().__init__()
         assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

@@ -25,80 +27,41 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
         self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
         self.dropout = nn.Dropout(dropout)

-        ####################################################
-        # NEW
-        self.max_seq_len = max_seq_len or context_length
-        self.window_size = window_size or self.max_seq_len
-        self.register_buffer("cache_k", None, persistent=False)
-        self.register_buffer("cache_v", None, persistent=False)
-        ####################################################
-
-    def forward(self, x, use_cache=False):
+    def forward(self, x, use_cache=False, start_pos=0, cache=None):
         b, num_tokens, d_in = x.shape

-        keys_new = self.W_key(x)  # Shape: (b, num_tokens, d_out)
-        values_new = self.W_value(x)
+        keys = self.W_key(x)  # Shape: (b, num_tokens, d_out)
+        values = self.W_value(x)
         queries = self.W_query(x)

         # We implicitly split the matrix by adding a `num_heads` dimension
         # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
-        keys_new = keys_new.view(b, num_tokens, self.num_heads, self.head_dim)
-        values_new = values_new.view(b, num_tokens, self.num_heads, self.head_dim)
+        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
+        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
         queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

         # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
-        keys_new = keys_new.transpose(1, 2)
-        values_new = values_new.transpose(1, 2)
+        keys = keys.transpose(1, 2)
         queries = queries.transpose(1, 2)
+        values = values.transpose(1, 2)

-        ####################################################
-        # NEW
         if use_cache:
-            if self.cache_k is None or self.cache_k.size(0) != b:
-                self.cache_k = torch.zeros(b, self.num_heads,
-                                           self.window_size, self.head_dim,
-                                           device=x.device)
-                self.cache_v = torch.zeros_like(self.cache_k)
-                self.ptr_cur = 0  # pointer to next free slot
-
-            # if incoming chunk would overflow discard oldest tokens
-            if self.ptr_cur + num_tokens > self.window_size:
-                overflow = self.ptr_cur + num_tokens - self.window_size
-                # shift everything left by `overflow` (cheap view-copy)
-                self.cache_k[:, :, :-overflow, :] = self.cache_k[:, :, overflow:, :].clone()
-                self.cache_v[:, :, :-overflow, :] = self.cache_v[:, :, overflow:, :].clone()
-                self.ptr_cur -= overflow  # pointer after shift
-
-            self.cache_k[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = keys_new
-            self.cache_v[:, :, self.ptr_cur:self.ptr_cur + num_tokens, :] = values_new
-            self.ptr_cur += num_tokens
-
-            keys = self.cache_k[:, :, :self.ptr_cur, :]
-            values = self.cache_v[:, :, :self.ptr_cur, :]
+            if cache is not None:
+                keys = torch.cat([cache[0], keys], dim=2)
+                values = torch.cat([cache[1], values], dim=2)
+            next_cache = (keys, values)
         else:
-            keys, values = keys_new, values_new
-            self.ptr_cur = 0  # keep pointer sane if you interleave modes
-        ####################################################
-        # Compute scaled dot-product attention (aka self-attention) with a causal mask
-        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
+            next_cache = None

-        ####################################################
-        # NEW
-        K = attn_scores.size(-1)
+        seq_len = keys.size(2)
+        causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device), diagonal=1)
+        causal_mask = causal_mask[-num_tokens:, :][None, None, :, :]

-        if num_tokens == K:
-            # No cache → use the pre‑baked triangular mask slice
-            causal_mask = torch.triu(torch.ones(num_tokens, K, device=x.device, dtype=torch.bool), diagonal=1)
-        else:
-            # Cached: need to offset the diagonal by (K − num_tokens)
-            offset = K - num_tokens  # number of tokens already in cache before this chunk
-            row_idx = torch.arange(num_tokens, device=x.device).unsqueeze(1)  # (num_tokens, 1)
-            col_idx = torch.arange(K, device=x.device).unsqueeze(0)  # (1, K)
-            causal_mask = row_idx + offset < col_idx  # True where j > i+offset
-        ####################################################
+        # Compute scaled dot-product attention (aka self-attention) with a causal mask
+        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

         # Use the mask to fill attention scores
-        attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), -torch.inf)
+        attn_scores.masked_fill_(causal_mask, -torch.inf)

         attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
         attn_weights = self.dropout(attn_weights)
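
The new mask construction builds the full `(seq_len, seq_len)` triangular mask over all cached-plus-new key positions and then keeps only the rows that belong to the new query tokens, so the same static shape logic works for both prefill and single-token decoding. A toy illustration of that slicing, with made-up sizes (one new query token against five key positions):

```python
import torch

seq_len, num_tokens = 5, 1  # five keys in total, one new query token
full = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
causal_mask = full[-num_tokens:, :][None, None, :, :]  # shape (1, 1, 1, 5)

# The last row of the triangular mask is all False, so the newest query token
# may attend to every key position, including itself:
print(causal_mask)
# tensor([[[[False, False, False, False, False]]]])
```
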
@@ -110,13 +73,7 @@ def forward(self, x, use_cache=False)
         context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
         context_vec = self.out_proj(context_vec)  # optional projection

-        return context_vec
-
-    ####################################################
-    # NEW
-    def reset_cache(self):
-        self.cache_k, self.cache_v = None, None
-    ####################################################
+        return context_vec, next_cache


 #####################################
@@ -169,25 +126,17 @@ def __init__(self, cfg):
             context_length=cfg["context_length"],
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
-            qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]  # NEW
-        )
+            qkv_bias=cfg["qkv_bias"])
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
         self.drop_shortcut = nn.Dropout(cfg["drop_rate"])

-    def forward(self, x, use_cache=False):
+    def forward(self, x, use_cache=False, start_pos=0, cache=None):
         # Shortcut connection for attention block
         shortcut = x
         x = self.norm1(x)
-
-        # x = self.att(x)  # Shape [batch_size, num_tokens, emb_size]
-        ####################################################
-        # NEW
-        x = self.att(x, use_cache=use_cache)
-        ####################################################
-
+        x, next_cache = self.att(x, use_cache=use_cache, start_pos=start_pos, cache=cache)  # Shape [batch_size, num_tokens, emb_size]
         x = self.drop_shortcut(x)
         x = x + shortcut  # Add the original input back

@@ -198,7 +147,7 @@ def forward(self, x, use_cache=False):
         x = self.drop_shortcut(x)
         x = x + shortcut  # Add the original input back

-        return x
+        return x, next_cache


 class GPTModel(nn.Module):
@@ -208,80 +157,34 @@ def __init__(self, cfg):
         self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"])
         self.drop_emb = nn.Dropout(cfg["drop_rate"])

-        # self.trf_blocks = nn.Sequential(
-        #    *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
-        ####################################################
-        # NEW
-        self.trf_blocks = nn.ModuleList(
-            [TransformerBlock(cfg) for _ in range(cfg["n_layers"])])
-
-        self.ptr_current_pos = 0
-        ####################################################
+        self.trf_blocks = nn.Sequential(
+            *[TransformerBlock(cfg) for _ in range(cfg["n_layers"])])

         self.final_norm = LayerNorm(cfg["emb_dim"])
         self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False)
+        self.current_pos = 0

-    def forward(self, in_idx, use_cache=False):
+    def forward(self, in_idx, use_cache=False, cache=None):
         batch_size, seq_len = in_idx.shape
+        pos = torch.arange(self.current_pos, self.current_pos + seq_len, device=in_idx.device)
         tok_embeds = self.tok_emb(in_idx)
-
-        # pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))
-
-        ####################################################
-        # NEW
+        pos_embeds = self.pos_emb(pos)
+        x = self.drop_emb(tok_embeds + pos_embeds)

         if use_cache:
-            pos_ids = torch.arange(self.ptr_current_pos, self.ptr_current_pos + seq_len, device=in_idx.device, dtype=torch.long)
-            self.ptr_current_pos += seq_len
+            start_pos = self.current_pos
+            self.current_pos += seq_len
         else:
-            pos_ids = torch.arange(0, seq_len, device=in_idx.device, dtype=torch.long)
-        pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
-        ####################################################
-
-        x = tok_embeds + pos_embeds  # Shape [batch_size, num_tokens, emb_size]
-        x = self.drop_emb(x)
+            start_pos = 0

-        # x = self.trf_blocks(x)
-        ####################################################
-        # NEW
-        for blk in self.trf_blocks:
-            x = blk(x, use_cache=use_cache)
-        ####################################################
+        next_cache = []
+        for i, block in enumerate(self.trf_blocks):
+            blk_cache = cache.get(i) if cache else None
+            x, new_cache = block(x, use_cache=use_cache, start_pos=start_pos, cache=blk_cache)
+            if cache:
+                cache.update(i, new_cache)
+            next_cache.append(new_cache)

         x = self.final_norm(x)
         logits = self.out_head(x)
         return logits
-
-    ####################################################
-    # NEW
-    def reset_kv_cache(self):
-        for blk in self.trf_blocks:
-            blk.att.reset_cache()
-        self.ptr_current_pos = 0
-    ####################################################
-
-
-def generate_text_simple(model, idx, max_new_tokens, context_size):
-    # idx is (B, T) array of indices in the current context
-    for _ in range(max_new_tokens):
-
-        # Crop current context if it exceeds the supported context size
-        # E.g., if LLM supports only 5 tokens, and the context size is 10
-        # then only the last 5 tokens are used as context
-        idx_cond = idx[:, -context_size:]
-
-        # Get the predictions
-        with torch.no_grad():
-            logits = model(idx_cond)
-
-        # Focus only on the last time step
-        # (batch, n_token, vocab_size) becomes (batch, vocab_size)
-        logits = logits[:, -1, :]
-
-        # Get the idx of the vocab entry with the highest logits value
-        idx_next = torch.argmax(logits, dim=-1, keepdim=True)  # (batch, 1)
-
-        # Append sampled index to the running sequence
-        idx = torch.cat((idx, idx_next), dim=1)  # (batch, n_tokens+1)
-
-    return idx
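
One detail worth noting: `generate_text_simple` above still calls `model.reset_kv_cache()`, while this diff removes the old per-layer `reset_cache` machinery from the attention module. With the new design, the cache tensors live in the externally created `KVCache` object, so the only model-side state left to rewind is `self.current_pos`; a matching reset method is not shown in this diff view and may be defined elsewhere, but it could be as small as `def reset_kv_cache(self): self.current_pos = 0`. The snippet below is a rough end-to-end usage sketch with a deliberately tiny made-up config; it assumes the `pkg/` directory maps to the installable `llms_from_scratch` package and that `GPTModel` exposes its config dict as `model.cfg`, which `generate_text_simple` reads:

```python
import torch

# Assumed import paths, based on the pkg/llms_from_scratch/ layout in this repo
from llms_from_scratch.kv_cache.gpt2 import GPTModel
from llms_from_scratch.kv_cache.generate import generate_text_simple

# Made-up toy config for illustration only; the real GPT-2 configs are larger.
cfg = {
    "vocab_size": 50257,
    "context_length": 256,
    "emb_dim": 128,
    "n_heads": 4,
    "n_layers": 2,
    "drop_rate": 0.0,
    "qkv_bias": False,
}

torch.manual_seed(123)
model = GPTModel(cfg)

prompt = torch.tensor([[1, 2, 3]])  # dummy token IDs, shape (batch=1, 3)
out = generate_text_simple(model, prompt, max_new_tokens=10, use_cache=True)
print(out.shape)  # expected: torch.Size([1, 13]) -> 3 prompt tokens + 10 generated
```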
