Llama3Fast #593

Merged 2 commits on Apr 1, 2025
143 changes: 138 additions & 5 deletions pkg/llms_from_scratch/llama3.py
@@ -67,7 +67,10 @@ def __init__(self, cfg):
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

# Reusable utilities
self.register_buffer("mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool())
self.register_buffer(
"mask", torch.triu(torch.ones(cfg["context_length"], cfg["context_length"]), diagonal=1).bool(),
persistent=False
)
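# Note: persistent=False (above) keeps this context_length x context_length mask
# out of state_dict(), so it is neither saved into nor expected from checkpoints.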

if cfg["orig_context_length"] != cfg["context_length"]:
cfg["rope_base"] = rescale_theta(
@@ -86,7 +89,6 @@ def __init__(self, cfg):
self.cfg = cfg

def forward(self, in_idx):
# Forward pass
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds

@@ -143,9 +145,7 @@ def forward(self, x):

class GroupedQueryAttention(nn.Module):
def __init__(
self, d_in, d_out, num_heads,
num_kv_groups,
dtype=None
self, d_in, d_out, num_heads, num_kv_groups, dtype=None
):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
@@ -375,3 +375,136 @@ def clean_text(text, header_end="assistant<|end_header_id|>\n\n"):
else:
# If the token is not found, return the original text
return text


######################################################################
# Llama 3 fast (alternative code geared towards efficiency)
######################################################################

class GroupedQueryAttentionFast(nn.Module):
"""
Drop-in replacement for GroupedQueryAttention but using PyTorch's
scaled_dot_product_attention, which uses FlashAttention if run
on an Ampere GPU (such as an A100) or newer with float16/bfloat16 or lower precision.
"""
def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
super().__init__()
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.num_kv_groups = num_kv_groups
self.group_size = num_heads // num_kv_groups

self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

def forward(self, x, cos, sin):
b, num_tokens, _ = x.shape

# Project to queries, keys, values
q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
k = self.W_key(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)
v = self.W_value(x).view(b, num_tokens, self.num_kv_groups, self.head_dim).transpose(1, 2)

# Apply Rotary Positional Embedding
q = apply_rope(q, cos, sin)
k = apply_rope(k, cos, sin)

# Expand key/value groups to full head count
k = k.repeat_interleave(self.group_size, dim=1)
v = v.repeat_interleave(self.group_size, dim=1)
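# Shape note (illustrative, e.g. num_heads=4, num_kv_groups=2, group_size=2):
# q is (b, 4, num_tokens, head_dim); k and v start as (b, 2, num_tokens, head_dim)
# and repeat_interleave(group_size, dim=1) expands them to (b, 4, num_tokens, head_dim),
# so each key/value head is shared by group_size query heads.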

# Efficient scaled dot-product attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
q, k, v,
is_causal=True # Enables Flash/FlexAttention kernels
)
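# is_causal=True applies the same upper-triangular (future-token) masking as the
# explicit triu mask buffer in GroupedQueryAttention, without materializing the mask.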

# Combine heads and project
attn_output = attn_output.transpose(1, 2).reshape(b, num_tokens, self.d_out)
return self.out_proj(attn_output)


class TransformerBlockFast(nn.Module):
"""
Same as original TransformerBlock but uses
GroupedQueryAttentionFast instead of GroupedQueryAttention.
"""
def __init__(self, cfg):
super().__init__()
self.att = GroupedQueryAttentionFast(
d_in=cfg["emb_dim"],
d_out=cfg["emb_dim"],
num_heads=cfg["n_heads"],
num_kv_groups=cfg["n_kv_groups"],
dtype=cfg["dtype"]
)
self.ff = FeedForward(cfg)
self.norm1 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.norm2 = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])

def forward(self, x, cos, sin):
# Shortcut connection for attention block
shortcut = x
x = self.norm1(x)
x = self.att(x, cos, sin) # Shape [batch_size, num_tokens, emb_size]
x = x + shortcut # Add the original input back

# Shortcut connection for feed-forward block
shortcut = x
x = self.norm2(x)
x = self.ff(x)
x = x + shortcut # Add the original input back

return x


class Llama3ModelFast(nn.Module):
"""
Same as original Llama3Model but uses TransformerBlockFast
instead of TransformerBlock, which in turn uses
GroupedQueryAttentionFast instead of GroupedQueryAttention.
"""
def __init__(self, cfg):
super().__init__()

# Main model parameters
self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"], dtype=cfg["dtype"])

self.trf_blocks = nn.ModuleList( # ModuleList since Sequential can only accept one input, and we need `x, cos, sin`
[TransformerBlockFast(cfg) for _ in range(cfg["n_layers"])]
)

self.final_norm = nn.RMSNorm(cfg["emb_dim"], eps=1e-5, dtype=cfg["dtype"])
self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False, dtype=cfg["dtype"])

if cfg["orig_context_length"] != cfg["context_length"]:
cfg["rope_base"] = rescale_theta(
cfg["rope_base"],
cfg["orig_context_length"],
cfg["context_length"]
)
cos, sin = compute_rope_params(
head_dim=cfg["emb_dim"] // cfg["n_heads"],
theta_base=cfg["rope_base"],
context_length=cfg["context_length"],
freq_config=cfg["rope_freq"]
)
self.register_buffer("cos", cos, persistent=False)
self.register_buffer("sin", sin, persistent=False)
self.cfg = cfg

def forward(self, in_idx):
tok_embeds = self.tok_emb(in_idx)
x = tok_embeds

for block in self.trf_blocks:
x = block(x, self.cos, self.sin)
x = self.final_norm(x)
logits = self.out_head(x.to(self.cfg["dtype"]))
return logits
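
Note: as a quick smoke test of the new fast path, the sketch below shows how Llama3ModelFast could be instantiated and run for a single forward pass. It is illustrative only and not part of this PR's diff: the import path llms_from_scratch.llama3 is assumed from the package layout, LLAMA32_CONFIG_1B is the config the tests import, and the token IDs are arbitrary.

import torch
from llms_from_scratch.llama3 import LLAMA32_CONFIG_1B, Llama3ModelFast  # assumed import path

torch.manual_seed(123)
model = Llama3ModelFast(LLAMA32_CONFIG_1B)
model.eval()

idx = torch.tensor([[1, 2, 3, 4]])  # arbitrary token IDs, shape (batch_size, num_tokens)
with torch.no_grad():
    logits = model(idx)
print(logits.shape)  # (1, 4, LLAMA32_CONFIG_1B["vocab_size"])
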
64 changes: 58 additions & 6 deletions pkg/llms_from_scratch/tests/test_llama3.py
@@ -9,7 +9,9 @@
apply_rope,
rescale_theta,
LLAMA32_CONFIG_1B,
Llama3Model
GroupedQueryAttention,
GroupedQueryAttentionFast,
Llama3Model,
)

import importlib
@@ -117,13 +119,63 @@ def test_rescale():
assert old_theta == 500_000.


def test_grouped_query_attention_equivalence():
torch.manual_seed(42)
b, t, d_in, d_out, num_heads, num_kv_groups = 2, 8, 32, 64, 4, 2

x = torch.randn(b, t, d_in)
cos, sin = compute_rope_params(
head_dim=d_out // num_heads,
theta_base=50_000,
context_length=t,
freq_config={
"factor": 32.0,
"low_freq_factor": 1.0,
"high_freq_factor": 4.0,
"original_context_length": t,
}
)

# Causal mask for the slow version
mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)

attn1 = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)
attn2 = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups)

# Copy weights to make both models identical
attn2.load_state_dict(attn1.state_dict())

# Run both
y1 = attn1(x, mask, cos, sin)
y2 = attn2(x, cos, sin)

# Compare outputs
max_diff = (y1 - y2).abs().max().item()
print(f"Max difference between slow and fast outputs: {max_diff:.4e}")
assert torch.allclose(y1, y2, atol=1e-4)


@pytest.fixture(scope="session")
def llama3_weights_path(tmp_path_factory):
"""Creates and saves a deterministic Llama3 model for testing."""
path = tmp_path_factory.mktemp("models") / "llama3_test_weights.pt"

if not path.exists():
torch.manual_seed(123)
model = Llama3Model(LLAMA32_CONFIG_1B)
torch.save(model.state_dict(), path)

return path


@pytest.mark.parametrize("ModelClass", [Llama3Model])
def test_gpt_model_variants(ModelClass):
def test_gpt_model_variants(ModelClass, llama3_weights_path):
torch.manual_seed(123)
model = ModelClass(LLAMA32_CONFIG_1B)
model.load_state_dict(torch.load(llama3_weights_path))
model.eval()

start_context = "Hello, I am"
start_context = "Llamas eat"

tokenizer = tiktoken.get_encoding("gpt2")
encoded = tokenizer.encode(start_context)
@@ -137,11 +189,11 @@ def test_gpt_model_variants(ModelClass):
out = generate_text_simple(
model=model,
idx=encoded_tensor,
max_new_tokens=10,
max_new_tokens=5,
context_size=LLAMA32_CONFIG_1B["context_length"]
)
print("Encoded output text:", out)
expect = torch.tensor([
[15496, 11, 314, 716, 78563, 89362, 19616, 115725, 114917,
97198, 60342, 19108, 100752, 98969]
[43, 2543, 292, 4483, 100383, 8113, 21197, 33804, 54419]
])
assert torch.equal(expect, out)
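
Note: beyond the equivalence check above, a rough way to see the efficiency angle is to time the two attention implementations side by side. The sketch below is illustrative only and not part of this PR's diff; it reuses the shapes and compute_rope_params arguments from test_grouped_query_attention_equivalence, assumes the llms_from_scratch.llama3 import path, and absolute timings will vary with hardware, dtype, and whether scaled_dot_product_attention can dispatch to a fused kernel.

import time
import torch
from llms_from_scratch.llama3 import (  # assumed import path
    GroupedQueryAttention,
    GroupedQueryAttentionFast,
    compute_rope_params,
)

b, t, d_in, d_out, num_heads, num_kv_groups = 4, 256, 512, 512, 8, 4
x = torch.randn(b, t, d_in)
cos, sin = compute_rope_params(
    head_dim=d_out // num_heads,
    theta_base=50_000,
    context_length=t,
    freq_config={
        "factor": 32.0,
        "low_freq_factor": 1.0,
        "high_freq_factor": 4.0,
        "original_context_length": t,
    },
)
mask = torch.triu(torch.ones(t, t, dtype=torch.bool), diagonal=1)  # only the slow path needs it

slow = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)
fast = GroupedQueryAttentionFast(d_in, d_out, num_heads, num_kv_groups)
fast.load_state_dict(slow.state_dict())  # identical weights so outputs match

def avg_seconds(fn, n=10):
    # crude wall-clock average; on GPU, prefer torch.cuda.Event with synchronization
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

with torch.no_grad():
    t_slow = avg_seconds(lambda: slow(x, mask, cos, sin))
    t_fast = avg_seconds(lambda: fast(x, cos, sin))
print(f"slow: {t_slow * 1e3:.2f} ms/call, fast: {t_fast * 1e3:.2f} ms/call")
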
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "llms-from-scratch"
version = "1.0.5"
version = "1.0.6"
description = "Implement a ChatGPT-like LLM in PyTorch from scratch, step by step"
readme = "README.md"
requires-python = ">=3.10"