
Commit 6522be9

martinzwm and rasbt authored
Fix bug in masking when kv cache is used. (#697)
* Fix bug in masking when kv cache is used.
* add tests
* dd tests
* upd
* add kv cache test to gh workflow
* explicit mask slicing
* upd

---------

Co-authored-by: rasbt <[email protected]>
1 parent 37b26c2 commit 6522be9

File tree

5 files changed: +179 -57 lines changed

.github/workflows/basic-tests-linux-uv.yml
ch04/03_kv-cache/README.md
ch04/03_kv-cache/gpt_with_kv_cache.py
ch04/03_kv-cache/gpt_with_kv_cache_optimized.py
ch04/03_kv-cache/tests.py

.github/workflows/basic-tests-linux-uv.yml

Lines changed: 1 addition & 0 deletions

@@ -49,6 +49,7 @@ jobs:
           source .venv/bin/activate
           pytest --ruff setup/02_installing-python-libraries/tests.py
           pytest --ruff ch04/01_main-chapter-code/tests.py
+          pytest --ruff ch04/03_kv-cache/tests.py
           pytest --ruff ch05/01_main-chapter-code/tests.py
           pytest --ruff ch05/07_gpt_to_llama/tests/tests.py
           pytest --ruff ch06/01_main-chapter-code/tests.py
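
The new test file can also be invoked locally from the repository root; a minimal sketch (not part of the commit), assuming pytest and the chapter's dependencies are installed. The CI step additionally passes `--ruff`, which requires the pytest-ruff plugin used in the workflow.

```python
import pytest

# Run the new KV-cache equivalence tests from the repository root.
# Append "--ruff" to the argument list to mirror the CI step exactly
# (requires the pytest-ruff plugin).
raise SystemExit(pytest.main(["ch04/03_kv-cache/tests.py"]))
```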

ch04/03_kv-cache/README.md

Lines changed: 30 additions & 18 deletions

@@ -86,6 +86,18 @@ def forward(self, x, use_cache=False):
         keys, values = self.cache_k, self.cache_v
     else:
         keys, values = keys_new, values_new
+
+    # ...
+
+    num_tokens_Q = queries.shape[-2]
+    num_tokens_K = keys.shape[-2]
+    if use_cache:
+        mask_bool = self.mask.bool()[
+            self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+        ]
+        self.ptr_current_pos += num_tokens_Q
+    else:
+        mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 ```
 
 &nbsp;
@@ -98,6 +110,7 @@ When generating texts, between independent sequences (for instance to text gener
 ```python
 def reset_cache(self):
     self.cache_k, self.cache_v = None, None
+    self.ptr_current_pos = 0
 ```
 
 &nbsp;
@@ -157,30 +170,29 @@ def reset_kv_cache(self):
 With the changes to the `GPTModel`, `TransformerBlock`, and `MultiHeadAttention`, finally, here's how we use the KV cache in a simple text generation function:
 
 ```python
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ```
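
As a quick illustration of what the new slicing computes (a minimal, self-contained sketch, not part of the commit): during cached generation the query rows of the causal mask are offset by `ptr_current_pos`, while the key columns span everything accumulated in the cache, so a freshly generated token is allowed to attend to all earlier positions.

```python
import torch

context_length = 8
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)

# Prompt step: 4 query tokens, 4 cached keys -> rows 0..3, first 4 key columns
ptr_current_pos = 0
num_tokens_Q, num_tokens_K = 4, 4
step1 = mask.bool()[ptr_current_pos:ptr_current_pos + num_tokens_Q, :num_tokens_K]
ptr_current_pos += num_tokens_Q

# Decode step: 1 new query token, 5 keys now in the cache -> row 4, first 5 key columns
num_tokens_Q, num_tokens_K = 1, 5
step2 = mask.bool()[ptr_current_pos:ptr_current_pos + num_tokens_Q, :num_tokens_K]

print(step1.shape)  # torch.Size([4, 4]) -- causal mask over the prompt
print(step2)        # tensor([[False, False, False, False, False]]) -- new token sees all cached keys
```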

ch04/03_kv-cache/gpt_with_kv_cache.py

Lines changed: 32 additions & 20 deletions

@@ -27,14 +27,15 @@ def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=Fal
         self.dropout = nn.Dropout(dropout)
         self.register_buffer(
             "mask",
-            torch.triu(torch.ones(context_length, context_length),diagonal=1),
+            torch.triu(torch.ones(context_length, context_length), diagonal=1),
             persistent=False
         )
 
         ####################################################
         # NEW
         self.register_buffer("cache_k", None, persistent=False)
         self.register_buffer("cache_v", None, persistent=False)
+        self.ptr_current_pos = 0
         ####################################################
 
     def forward(self, x, use_cache=False):
@@ -71,8 +72,19 @@ def forward(self, x, use_cache=False):
         # Compute scaled dot-product attention (aka self-attention) with a causal mask
         attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head
 
+        ####################################################
+        # NEW
+        num_tokens_Q = queries.shape[-2]
+        num_tokens_K = keys.shape[-2]
+        if use_cache:
+            mask_bool = self.mask.bool()[
+                self.ptr_current_pos:self.ptr_current_pos + num_tokens_Q, :num_tokens_K
+            ]
+            self.ptr_current_pos += num_tokens_Q
+        ####################################################
         # Original mask truncated to the number of tokens and converted to boolean
-        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
+        else:
+            mask_bool = self.mask.bool()[:num_tokens_Q, :num_tokens_K]
 
         # Use the mask to fill attention scores
         attn_scores.masked_fill_(mask_bool, -torch.inf)
@@ -93,6 +105,7 @@ def forward(self, x, use_cache=False):
     # NEW
     def reset_cache(self):
         self.cache_k, self.cache_v = None, None
+        self.ptr_current_pos = 0
     ####################################################
 
 
@@ -264,30 +277,29 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens,
+                                context_size=None, use_cache=True):
     model.eval()
+    ctx_len = context_size or model.pos_emb.num_embeddings
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    with torch.no_grad():
+        if use_cache:
+            # Init cache with full prompt
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                # a) pick the token with the highest log-probability (greedy sampling)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                # b) append it to the running sequence
+                idx = torch.cat([idx, next_idx], dim=1)
+                # c) feed model only the new token
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
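
For context, a hedged usage sketch (not part of the commit) of the updated `generate_text_simple_cached` signature, mirroring the configuration used in `tests.py`; `context_size` is now optional and falls back to `model.pos_emb.num_embeddings` when omitted.

```python
import torch
import tiktoken
from gpt_with_kv_cache import GPTModel, generate_text_simple_cached

GPT_CONFIG_124M = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False,
}

torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)
tokenizer = tiktoken.get_encoding("gpt2")
idx = torch.tensor(tokenizer.encode("Hello, I am")).unsqueeze(0)

# context_size may be passed explicitly or left as None
# (then it falls back to the model's positional-embedding size, 1024 here)
token_ids = generate_text_simple_cached(
    model=model, idx=idx, max_new_tokens=10,
    context_size=GPT_CONFIG_124M["context_length"],
)
print(tokenizer.decode(token_ids.squeeze(0).tolist()))
```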

ch04/03_kv-cache/gpt_with_kv_cache_optimized.py

Lines changed: 15 additions & 19 deletions

@@ -171,7 +171,8 @@ def __init__(self, cfg):
             num_heads=cfg["n_heads"],
             dropout=cfg["drop_rate"],
             qkv_bias=cfg["qkv_bias"],
-            window_size=cfg["kv_window_size"])  # NEW
+            window_size=cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]  # NEW
+        )
         self.ff = FeedForward(cfg)
         self.norm1 = LayerNorm(cfg["emb_dim"])
         self.norm2 = LayerNorm(cfg["emb_dim"])
@@ -289,30 +290,25 @@ def generate_text_simple(model, idx, max_new_tokens, context_size):
 
 ####################################################
 # NEW
-def generate_text_simple_cached(model, idx, max_new_tokens, use_cache=True):
+def generate_text_simple_cached(model, idx, max_new_tokens, context_size=None, use_cache=True):
     model.eval()
 
-    ctx_len = model.pos_emb.num_embeddings  # max supported length, e.g. 1024
-    if use_cache:
-        # Init cache with full prompt
-        model.reset_kv_cache()
-        with torch.no_grad():
+    ctx_len = context_size or model.pos_emb.num_embeddings
+
+    with torch.no_grad():
+        if use_cache:
+            model.reset_kv_cache()
             logits = model(idx[:, -ctx_len:], use_cache=True)
 
-        for _ in range(max_new_tokens):
-            # a) pick the token with the highest log-probability (greedy sampling)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            # b) append it to the running sequence
-            idx = torch.cat([idx, next_idx], dim=1)
-            # c) feed model only the new token
-            with torch.no_grad():
+            for _ in range(max_new_tokens):
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
                 logits = model(next_idx, use_cache=True)
-    else:
-        for _ in range(max_new_tokens):
-            with torch.no_grad():
+        else:
+            for _ in range(max_new_tokens):
                 logits = model(idx[:, -ctx_len:], use_cache=False)
-            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
-            idx = torch.cat([idx, next_idx], dim=1)
+                next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
+                idx = torch.cat([idx, next_idx], dim=1)
 
     return idx
 ####################################################
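
A small illustration (a sketch of the changed expression, not code from the commit) of the new fallback in the constructor: the attention's `window_size` now defaults to the full context length when the config does not define `"kv_window_size"`.

```python
# No "kv_window_size" key -> fall back to the full context length
cfg = {"context_length": 1024}
window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
print(window_size)  # 1024

# An explicit sliding-window size takes precedence
cfg["kv_window_size"] = 48
window_size = cfg["kv_window_size"] if "kv_window_size" in cfg else cfg["context_length"]
print(window_size)  # 48
```

An equivalent, slightly more compact spelling would be `cfg.get("kv_window_size", cfg["context_length"])`.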

ch04/03_kv-cache/tests.py

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+# Code to test the GPT model implementation against the KV cache variants
+
+import pytest
+import torch
+import tiktoken
+
+from gpt_ch04 import GPTModel as GPTModelBase
+from gpt_ch04 import generate_text_simple
+
+from gpt_with_kv_cache import GPTModel as GPTModelKV1
+from gpt_with_kv_cache_optimized import GPTModel as GPTModelKV2
+from gpt_with_kv_cache import generate_text_simple_cached
+
+
+GPT_CONFIG_124M = {
+    "vocab_size": 50257,
+    "context_length": 1024,
+    "emb_dim": 768,
+    "n_heads": 12,
+    "n_layers": 12,
+    "drop_rate": 0.1,
+    "qkv_bias": False,
+}
+
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_not_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded = tokenizer.encode(prompt)
+    encoded_tensor = torch.tensor(encoded, device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    token_ids = generate_text_simple(
+        model=model,
+        idx=encoded_tensor,
+        max_new_tokens=30,
+        context_size=GPT_CONFIG_124M["context_length"]
+    )
+
+    if not hasattr(test_gpt_model_equivalence_not_cached, "results"):
+        test_gpt_model_equivalence_not_cached.results = []
+
+    test_gpt_model_equivalence_not_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_not_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_not_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_not_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )
+
+
+@pytest.mark.parametrize("ModelClass", [GPTModelBase, GPTModelKV1, GPTModelKV2])
+def test_gpt_model_equivalence_cached(ModelClass):
+    torch.manual_seed(123)
+
+    model = ModelClass(GPT_CONFIG_124M).to(device)
+    model.eval()
+
+    tokenizer = tiktoken.get_encoding("gpt2")
+    prompt = "Hello, I am"
+    encoded_tensor = torch.tensor(tokenizer.encode(prompt), device=device).unsqueeze(0)
+
+    model_name = ModelClass.__module__ + "." + ModelClass.__name__
+
+    if ModelClass is GPTModelBase:
+        token_ids = generate_text_simple(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+    else:
+        token_ids = generate_text_simple_cached(
+            model=model,
+            idx=encoded_tensor,
+            max_new_tokens=30,
+            context_size=GPT_CONFIG_124M["context_length"]
+        )
+
+    if not hasattr(test_gpt_model_equivalence_cached, "results"):
+        test_gpt_model_equivalence_cached.results = []
+
+    test_gpt_model_equivalence_cached.results.append((model_name, token_ids))
+
+    if len(test_gpt_model_equivalence_cached.results) == 3:
+        base_name, base_output = test_gpt_model_equivalence_cached.results[0]
+        for other_name, other_output in test_gpt_model_equivalence_cached.results[1:]:
+            assert torch.equal(base_output, other_output), (
+                f"Mismatch between {base_name} and {other_name}"
+            )
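
Outside of pytest, the same equivalence can be checked with a standalone script; a minimal sketch (not part of the commit), comparing the baseline model against one KV-cache variant under identical seeding and greedy decoding:

```python
import torch
import tiktoken

from gpt_ch04 import GPTModel as GPTModelBase, generate_text_simple
from gpt_with_kv_cache import GPTModel as GPTModelKV1, generate_text_simple_cached

GPT_CONFIG_124M = {
    "vocab_size": 50257, "context_length": 1024, "emb_dim": 768,
    "n_heads": 12, "n_layers": 12, "drop_rate": 0.1, "qkv_bias": False,
}

tokenizer = tiktoken.get_encoding("gpt2")
idx = torch.tensor(tokenizer.encode("Hello, I am")).unsqueeze(0)

# Same seed -> identical random weight initialization for both variants
torch.manual_seed(123)
base = GPTModelBase(GPT_CONFIG_124M).eval()
torch.manual_seed(123)
cached = GPTModelKV1(GPT_CONFIG_124M).eval()

out_base = generate_text_simple(
    model=base, idx=idx, max_new_tokens=30,
    context_size=GPT_CONFIG_124M["context_length"],
)
out_cached = generate_text_simple_cached(
    model=cached, idx=idx, max_new_tokens=30,
    context_size=GPT_CONFIG_124M["context_length"],
)

assert torch.equal(out_base, out_cached), "KV-cache output diverges from the baseline"
print("Baseline and KV-cache outputs match.")
```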
