
Commit aedad7e

Add Llama 3.2 to pkg (#591)
* Add Llama 3.2 to pkg
* remove redundant attributes
* update tests
* updates
* updates
* updates
* fix link
* fix link
1 parent 152a087 commit aedad7e

File tree

7 files changed: +719, -6 lines changed


.github/workflows/basic-tests-linux-uv.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -71,4 +71,5 @@ jobs:
         shell: bash
         run: |
           source .venv/bin/activate
+          uv pip install transformers
           pytest pkg/llms_from_scratch/tests/
```

.github/workflows/check-links.yml

Lines changed: 0 additions & 4 deletions
```diff
@@ -24,8 +24,6 @@ jobs:
         run: |
           curl -LsSf https://astral.sh/uv/install.sh | sh
           uv add pytest-ruff pytest-check-links
-          # Current version of retry doesn't work well if there are broken non-URL links
-          # pip install pytest pytest-check-links pytest-retry
 
       - name: Check links
         run: |
@@ -40,5 +38,3 @@ jobs:
           --check-links-ignore "https://arxiv.org/*" \
           --check-links-ignore "https://ai.stanford.edu/~amaas/data/sentiment/" \
           --check-links-ignore "https://x.com/*"
-          # pytest --check-links ./ --check-links-ignore "https://platform.openai.com/*" --check-links-ignore "https://arena.lmsys.org" --retries 2 --retry-delay 5
-
```

ch05/07_gpt_to_llama/README.md

Lines changed: 185 additions & 1 deletion
@@ -8,4 +8,188 @@

- [converting-llama2-to-llama3.ipynb](converting-llama2-to-llama3.ipynb): contains code to convert the Llama 2 model to Llama 3, Llama 3.1, and Llama 3.2
- [standalone-llama32.ipynb](standalone-llama32.ipynb): a standalone notebook implementing Llama 3.2

<img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/bonus/gpt-to-llama/gpt-and-all-llamas.webp">

&nbsp;
### Using Llama 3.2 via the `llms-from-scratch` package

For an easy way to run the Llama 3.2 1B and 3B models, you can use the `llms-from-scratch` PyPI package, which is based on the source code in this repository at [pkg/llms_from_scratch](../../pkg/llms_from_scratch).

&nbsp;
##### 1) Installation

```bash
pip install llms_from_scratch blobfile
```
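
To verify the install, you can optionally check that the package and its Llama 3 module import cleanly (a quick, illustrative check; `blobfile` is a tokenizer dependency used in step 4 below):

```python
# Optional sanity check: the package and its Llama 3 module should import cleanly
import llms_from_scratch
from llms_from_scratch import llama3

print("llms_from_scratch imported successfully")
```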
&nbsp;
##### 2) Model and text generation settings

Specify which model to use:

```python
MODEL_FILE = "llama3.2-1B-instruct.pth"
# MODEL_FILE = "llama3.2-1B-base.pth"
# MODEL_FILE = "llama3.2-3B-instruct.pth"
# MODEL_FILE = "llama3.2-3B-base.pth"
```

Next, define the basic text generation settings. Note that the recommended 8192-token context size requires approximately 3 GB of VRAM for the text generation example.

```python
MODEL_CONTEXT_LENGTH = 8192  # Supports up to 131_072

# Text generation settings
if "instruct" in MODEL_FILE:
    PROMPT = "What do llamas eat?"
else:
    PROMPT = "Llamas eat"

MAX_NEW_TOKENS = 150
TEMPERATURE = 0.
TOP_K = 1
```
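
With `TEMPERATURE = 0.` and `TOP_K = 1`, decoding is effectively greedy: the most probable token is always selected. As a rough illustration of what these two knobs control, below is a minimal top-k/temperature sampling sketch; it mirrors the common approach but is not necessarily the package's exact `generate` implementation:

```python
import torch

def sample_next_token(logits, top_k=None, temperature=0.0):
    # Keep only the top-k logits and mask the rest with -inf
    if top_k is not None:
        top_logits, _ = torch.topk(logits, top_k)
        logits = torch.where(
            logits < top_logits[..., -1],
            torch.tensor(float("-inf")),
            logits,
        )
    # With temperature > 0, scale the logits and sample from the softmax;
    # otherwise fall back to greedy argmax decoding
    if temperature > 0.0:
        probs = torch.softmax(logits / temperature, dim=-1)
        return torch.multinomial(probs, num_samples=1)
    return torch.argmax(logits, dim=-1, keepdim=True)
```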
&nbsp;
##### 3) Weight download and loading

This automatically downloads the weight file based on the model choice above:

```python
import os
import urllib.request

url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}"

if not os.path.exists(MODEL_FILE):
    urllib.request.urlretrieve(url, MODEL_FILE)
    print(f"Downloaded to {MODEL_FILE}")
```
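
Alternatively, if you already have the `huggingface_hub` library installed, the same file can be fetched with it (an optional substitute for the `urllib` snippet above):

```python
from huggingface_hub import hf_hub_download

# Fetches the selected weight file into the current directory
hf_hub_download(
    repo_id="rasbt/llama-3.2-from-scratch",
    filename=MODEL_FILE,
    local_dir=".",
)
```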
The model weights are then loaded as follows:

```python
import torch
from llms_from_scratch.llama3 import Llama3Model

if "1B" in MODEL_FILE:
    from llms_from_scratch.llama3 import LLAMA32_CONFIG_1B as LLAMA32_CONFIG
elif "3B" in MODEL_FILE:
    from llms_from_scratch.llama3 import LLAMA32_CONFIG_3B as LLAMA32_CONFIG
else:
    raise ValueError("Incorrect model file name")

LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH

model = Llama3Model(LLAMA32_CONFIG)
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))

device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)
model.to(device)
```
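
As an optional sanity check, you can print the parameter count of the loaded model; for the 1B variant this should come out to roughly 1.2 billion parameters (the exact number depends on the configuration):

```python
# Optional: report the total number of model parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params:,}")
```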
&nbsp;
##### 4) Initialize tokenizer

The following code downloads and initializes the tokenizer:

```python
from llms_from_scratch.llama3 import Llama3Tokenizer, ChatFormat, clean_text

TOKENIZER_FILE = "tokenizer.model"

url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{TOKENIZER_FILE}"

if not os.path.exists(TOKENIZER_FILE):
    urllib.request.urlretrieve(url, TOKENIZER_FILE)
    print(f"Downloaded to {TOKENIZER_FILE}")

tokenizer = Llama3Tokenizer("tokenizer.model")

if "instruct" in MODEL_FILE:
    tokenizer = ChatFormat(tokenizer)
```
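
To confirm the tokenizer works, an optional round-trip through the chapter 5 helper functions (imported again in step 5 below) looks like this:

```python
from llms_from_scratch.ch05 import text_to_token_ids, token_ids_to_text

# Encode a test prompt to a (1, num_tokens) tensor of token IDs and decode it back
ids = text_to_token_ids("What do llamas eat?", tokenizer)
print(ids.shape)
print(token_ids_to_text(ids, tokenizer))
```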
&nbsp;
##### 5) Generating text

Lastly, we can generate text via the following code:

```python
import time

from llms_from_scratch.ch05 import (
    generate,
    text_to_token_ids,
    token_ids_to_text
)

torch.manual_seed(123)

start = time.time()

token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=TOP_K,
    temperature=TEMPERATURE
)

print(f"Time: {time.time() - start:.2f} sec")

if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)

if "instruct" in MODEL_FILE:
    output_text = clean_text(output_text)

print("\n\nOutput text:\n\n", output_text)
```
When using the Llama 3.2 1B Instruct model, the output should look similar to the one shown below:

```
Time: 4.12 sec
Max memory allocated: 2.91 GB


Output text:

Llamas are herbivores, which means they primarily eat plants. Their diet consists mainly of:

1. Grasses: Llamas love to graze on various types of grasses, including tall grasses and grassy meadows.
2. Hay: Llamas also eat hay, which is a dry, compressed form of grass or other plants.
3. Alfalfa: Alfalfa is a legume that is commonly used as a hay substitute in llama feed.
4. Other plants: Llamas will also eat other plants, such as clover, dandelions, and wild grasses.

It's worth noting that the specific diet of llamas can vary depending on factors such as the breed,
```

&nbsp;
**Pro tip**

For up to a 4× speed-up, replace

```python
model.to(device)
```

with

```python
model = torch.compile(model)
model.to(device)
```

Note: the speed-up takes effect after the first `generate` call.
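
Since compilation happens lazily, the first `generate` call pays the one-time compile cost. To keep timing measurements clean, you can optionally trigger compilation up front with a short warm-up generation such as this illustrative one:

```python
# Optional warm-up: one short generation so torch.compile finishes
# compiling before any timed runs
_ = generate(
    model=model,
    idx=text_to_token_ids("Warm-up", tokenizer).to(device),
    max_new_tokens=1,
    context_size=LLAMA32_CONFIG["context_length"],
)
```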

pkg/llms_from_scratch/README.md

Lines changed: 8 additions & 0 deletions
````diff
@@ -109,5 +109,13 @@ from llms_from_scratch.ch07 import (
 from llms_from_scratch.appendix_a import NeuralNetwork, ToyDataset
 
 from llms_from_scratch.appendix_d import find_highest_gradient, train_model
+
+from llms_from_scratch.llama3 import (
+    Llama3Model,
+    Llama3Tokenizer,
+    ChatFormat,
+    clean_text
+)
 ```
 
+For the `llms_from_scratch.llama3` usage information, please see [this bonus section](../../ch05/07_gpt_to_llama/README.md).
````
